Skip to content

Commit db1551b

Browse files
anijain2305pytorchmergebot
authored andcommitted
[pytree][compile] Slightly faster TreeSpec init (pytorch#168024)
Helps with reducing Dynamo tracing time. Earlier the generator object would cause more polyfills. Pull Request resolved: pytorch#168024 Approved by: https://github.com/williamwen42
1 parent 7392106 commit db1551b

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

torch/_dynamo/polyfills/pytree.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,11 @@ def __post_init__(self, /) -> None:
201201
num_children = 0
202202
else:
203203
assert callable(self._unflatten_func)
204-
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
205-
num_leaves = sum(spec.num_leaves for spec in self._children)
204+
num_nodes = 1
205+
num_leaves = 0
206+
for child in self._children:
207+
num_nodes += child.num_nodes
208+
num_leaves += child.num_leaves
206209
num_children = len(self._children)
207210

208211
object.__setattr__(self, "num_nodes", num_nodes)

torch/utils/_pytree.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,11 @@ def __post_init__(self) -> None:
11131113
num_leaves = 1
11141114
num_children = 0
11151115
else:
1116-
num_nodes = sum((spec.num_nodes for spec in self._children), start=1)
1117-
num_leaves = sum(spec.num_leaves for spec in self._children)
1116+
num_nodes = 1
1117+
num_leaves = 0
1118+
for child in self._children:
1119+
num_nodes += child.num_nodes
1120+
num_leaves += child.num_leaves
11181121
num_children = len(self._children)
11191122
object.__setattr__(self, "num_nodes", num_nodes)
11201123
object.__setattr__(self, "num_leaves", num_leaves)

0 commit comments

Comments
 (0)