Skip to content

Commit c7cf3fb

Browse files
Revert "[pytree][compile] Slightly faster TreeSpec init (pytorch#168024)"
This reverts commit db1551b. Reverted pytorch#168024 on behalf of https://github.com/yangw-dev due to Internal merge fail, These changes have conflicts when merging with master branch. Rebase this diff. please rebase the pr and try merge again ([comment](pytorch#168024 (comment)))
1 parent 5abb7bf commit c7cf3fb

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

torch/_dynamo/polyfills/pytree.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,8 @@ def __post_init__(self, /) -> None:
201201
num_children = 0
202202
else:
203203
assert callable(self._unflatten_func)
204-
num_nodes = 1
205-
num_leaves = 0
206-
for child in self._children:
207-
num_nodes += child.num_nodes
208-
num_leaves += child.num_leaves
204+
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
205+
num_leaves = sum(spec.num_leaves for spec in self._children)
209206
num_children = len(self._children)
210207

211208
object.__setattr__(self, "num_nodes", num_nodes)

torch/utils/_pytree.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,11 +1113,8 @@ def __post_init__(self) -> None:
11131113
num_leaves = 1
11141114
num_children = 0
11151115
else:
1116-
num_nodes = 1
1117-
num_leaves = 0
1118-
for child in self._children:
1119-
num_nodes += child.num_nodes
1120-
num_leaves += child.num_leaves
1116+
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
1117+
num_leaves = sum(spec.num_leaves for spec in self._children)
11211118
num_children = len(self._children)
11221119
object.__setattr__(self, "num_nodes", num_nodes)
11231120
object.__setattr__(self, "num_leaves", num_leaves)

0 commit comments

Comments
 (0)