
Commit ff039d3

mlazos authored and pytorchmergebot committed
[Dynamo] Optimize dedupe region ancestor tracking (pytorch#152589)
Pull Request resolved: pytorch#152589
Approved by: https://github.com/anijain2305
ghstack dependencies: pytorch#152389, pytorch#152505, pytorch#152410, pytorch#152506, pytorch#152570, pytorch#152572
1 parent d0faa99 commit ff039d3

4 files changed: +103 -86 lines changed
Lines changed: 12 additions & 12 deletions

@@ -1,8 +1,8 @@
-add_loop_eager,compile_time_instruction_count,3167000000,0.015
+add_loop_eager,compile_time_instruction_count,3035000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,6066000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025

@@ -14,11 +14,11 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26050000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,1018000000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,1011000000,0.015

@@ -34,44 +34,44 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
-update_hint_regression,compile_time_instruction_count,1723000000,0.02
+update_hint_regression,compile_time_instruction_count,1715000000,0.02
 float_args,compile_time_instruction_count,439200000,0.015
-sum_floordiv_regression,compile_time_instruction_count,1024000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1009000000,0.015
-symint_sum,compile_time_instruction_count,3278000000,0.015
+symint_sum,compile_time_instruction_count,3252000000,0.015
-symint_sum_loop,compile_time_instruction_count,4300000000,0.015
+symint_sum_loop,compile_time_instruction_count,4262000000,0.015
 aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,8586000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
 aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10280000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
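Each row above has the shape benchmark_name,metric,expected_value,noise_margin; this commit only adjusts the expected compile-time instruction counts (most decrease, a few aotdispatcher entries increase slightly). As a rough sketch of how such a row can be read, assuming the last column is a relative tolerance around the expected count (within_margin is a hypothetical helper, not the PyTorch benchmark harness):

def within_margin(actual: int, expected: int, margin: float) -> bool:
    # Interpret the last CSV column as a relative tolerance on the expected count.
    return abs(actual - expected) <= expected * margin

# Example with the updated add_loop_eager row (expected 3035000000, margin 0.015):
print(within_margin(3_050_000_000, 3_035_000_000, 0.015))  # True: ~0.5% off, under 1.5%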

test/dynamo/test_graph_region_tracker.py

Lines changed: 2 additions & 2 deletions

@@ -61,12 +61,12 @@ def inner_fn(x, y):
     return z

 def fn(x, y):
-    _o0 = inner_fn(x, y)
+    o0 = inner_fn(x, y)
     o1 = torch.sin(y)
     o2 = inner_fn(x, o1)
     o3 = inner_fn(x, y)
     o4 = o3 * o3
-    return o2 * o4
+    return o2 * o4 + o0

 self.assertExpectedInline(
     self.get_result(

torch/_dynamo/graph_region_tracker.py

Lines changed: 75 additions & 66 deletions

@@ -151,6 +151,9 @@ def __init__(self, origin: Node) -> None:
     def create(origin: Node) -> "BackwardBfsArgIter":
         it = BackwardBfsArgIter(origin)
         it.add_children(origin)
+        # pop the origin node, since it is the origin of
+        # the region and does not need to be considered for addition
+        assert it.next()
         return it

     def next(self) -> Optional[Node]:

@@ -165,17 +168,11 @@ def peek(self) -> Optional[Node]:
         return self._cur

     def add_children(self, node: Node) -> None:
-        arg: Any
-        flat_args, _ = tree_flatten(node.args)
+        flat_args = _get_flat_args_unique(node, {})
         for arg in flat_args:
             if isinstance(arg, Node):
                 self._append(arg)

-        flat_kwargs, _ = tree_flatten(node.kwargs)
-        for kwarg in flat_kwargs:
-            if isinstance(kwarg, Node):
-                self._append(kwarg)
-
     def _append(self, arg: Node) -> None:
         if self._cur is None:
             self._cur = arg

@@ -328,6 +325,38 @@ def __str__(self) -> str:
         return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"


+class RegionWrapper:
+    """Holds state for regions e.g. ancestors and new candidate nodes for consideration"""
+
+    def __init__(
+        self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
+    ) -> None:
+        assert len(region) == 1, "all regions should start with one node"
+        node = region[0]
+        self.node_to_recursive_ancestors = node_to_recursive_ancestors
+        self.iter = BackwardBfsArgIter.create(node)
+        self.nodes_unique = OrderedSet([node])
+        self.ancestors = set(node_to_recursive_ancestors[node])
+        self.region = region
+
+    def next_candidate(self) -> Optional[Node]:
+        return self.iter.next()
+
+    def will_inclusion_create_cycle(self, node: Node) -> bool:
+        external_users = [user for user in node.users if user not in self.nodes_unique]
+        for user in external_users:
+            if user in self.ancestors:
+                return True
+
+        return False
+
+    def add(self, node: Node) -> None:
+        self.nodes_unique.add(node)
+        self.region.append(node)
+        self.iter.add_children(node)
+        self.ancestors.update(self.node_to_recursive_ancestors[node])
+
+
 def fully_expand_region_group(
     regions: list[Region],
     seen_nodes: set[Node],

@@ -339,20 +368,12 @@ def fully_expand_region_group(

     # All regions should start with 1 node
     assert all(len(region) == 1 for region in regions)
-    region_iters = []
-    for region in regions:
-        (origin,) = region  # Only works for 1 element sets
-        region_iters.append(BackwardBfsArgIter.create(origin))
-
-    nodes_to_add: list[Node] = []
-
-    # we already have the origin node in each region
-    for region_it in region_iters:
-        node = region_it.next()
-        assert node
-        region_it.add_children(node)
+    region_wrappers = [
+        RegionWrapper(region, node_to_recursive_ancestors) for region in regions
+    ]

-    current_node = region_iters[0].next()
+    nodes_to_add = OrderedSet[Node]()
+    current_node = region_wrappers[0].next_candidate()

     # No children
     if current_node is None:

@@ -362,46 +383,51 @@
     # regions are only expanded if the node to add is valid
     # for ALL regions
     while current_node:
-        add_node = not _will_create_cycle(
-            current_node, regions[0], node_to_recursive_ancestors
+        add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle(
+            current_node
         )
         nodes_to_add.clear()
-        nodes_to_add.append(current_node)
-        nodes_to_add_set = set(nodes_to_add)
-        for ind, region_it in enumerate(region_iters[1:]):
-            ind += 1  # compensate for the 0th region
-            node = region_it.next()
+        nodes_to_add.add(current_node)
+        for region_wrapper in region_wrappers[1:]:
+            candidate = region_wrapper.next_candidate()

             debug_log("--------------------")
-            debug_log("considering adding: %s, cur_node: %s", node, current_node)
-            debug_log("previously claimed nodes: %s", node in seen_nodes)
-            if node:
-                debug_log("is_identical: %s", is_identical_fn(node, current_node))
-                add_node &= (
-                    node not in seen_nodes
-                    and node not in nodes_to_add_set
-                    and node.op != "placeholder"
-                    and is_identical_fn(node, current_node)
-                    and not _will_create_cycle(
-                        node, regions[ind], node_to_recursive_ancestors
-                    )
-                )
-                nodes_to_add.append(node)
-                nodes_to_add_set.add(node)
-            else:
-                add_node = False
+            debug_log(
+                "considering candidate: %s, cur_node: %s", candidate, current_node
+            )
+
+            if not candidate or not add_to_all_regions:
+                add_to_all_regions = False
+                continue
+
+            debug_log(
+                "candidate in previously claimed nodes?: %s", candidate in seen_nodes
+            )
+            debug_log("is_identical: %s", is_identical_fn(candidate, current_node))
+
+            add_to_all_regions &= (
+                candidate not in seen_nodes
+                and candidate not in nodes_to_add
+                and candidate.op != "placeholder"
+                and is_identical_fn(candidate, current_node)
+                and not region_wrapper.will_inclusion_create_cycle(candidate)
+            )
+            nodes_to_add.add(candidate)

+            debug_log(f"add_to_all_regions: {add_to_all_regions}")
             debug_log("--------------------")

-        if add_node:
-            for region, region_it, node in zip(regions, region_iters, nodes_to_add):
-                region.append(node)
+        if add_to_all_regions:
+            assert len(region_wrappers) == len(nodes_to_add), (
+                "Number of nodes to add must equal the number of regions"
+            )
+            for region_wrapper, node in zip(region_wrappers, nodes_to_add):
+                region_wrapper.add(node)
                 debug_log("adding %s's children", node)
                 debug_log("%s %s", node.args, list(node.kwargs.items()))
-                region_it.add_children(node)
                 seen_nodes.add(node)

-        current_node = region_iters[0].next()
+        current_node = region_wrappers[0].next_candidate()

     # Ensure regions are sorted in topological order
     for region in regions:

@@ -424,20 +450,3 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]:
         )
         node_to_recursive_ancestors[node].add(arg)
     return node_to_recursive_ancestors
-
-
-def _will_create_cycle(
-    node_to_add: Node,
-    region: Region,
-    node_to_recursive_ancestors: dict[Node, set[Node]],
-) -> bool:
-    region_set: set[Node] = set(region)
-    region_ancestors: set[Node] = set(
-        tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
-    )
-    external_users = [user for user in node_to_add.users if user not in region_set]
-    for user in external_users:
-        if user in region_ancestors:
-            return True
-
-    return False
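The heart of the change is visible in the hunks above: the deleted _will_create_cycle helper rebuilt the region's ancestor set with tree_flatten on every candidate node, whereas RegionWrapper carries the region's node set (nodes_unique) and ancestor set (ancestors) forward and updates them incrementally in add(), so each cycle check only walks the candidate's users. Below is a self-contained sketch of that idea using plain strings instead of torch.fx nodes; TinyRegion and its fields are illustrative names, not part of the PyTorch API.

# Minimal sketch of incremental ancestor tracking for cycle checks.
# "users" maps a node to the nodes that consume it; "ancestors" maps a node
# to all of its transitive inputs. Both are assumed to be precomputed.
class TinyRegion:
    def __init__(self, start, ancestors, users):
        self.ancestors_map = ancestors
        self.users = users
        self.nodes = {start}
        # Maintained incrementally instead of re-flattened for every candidate.
        self.region_ancestors = set(ancestors[start])

    def will_inclusion_create_cycle(self, node):
        # A cycle would form if a user of `node` outside the region is also
        # an ancestor of something already inside the region.
        return any(
            user in self.region_ancestors
            for user in self.users.get(node, ())
            if user not in self.nodes
        )

    def add(self, node):
        self.nodes.add(node)
        self.region_ancestors |= self.ancestors_map[node]

# Tiny graph: a -> b -> c and a -> c. Growing a region that starts at c:
users = {"a": {"b", "c"}, "b": {"c"}, "c": set()}
ancestors = {"a": set(), "b": {"a"}, "c": {"a", "b"}}
region = TinyRegion("c", ancestors, users)
print(region.will_inclusion_create_cycle("b"))  # False: b's only user, c, is already in the region
region.add("b")  # updates region_ancestors in O(|ancestors of b|), no re-flattening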

torch/_dynamo/utils.py

Lines changed: 14 additions & 6 deletions

@@ -3152,12 +3152,20 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
     args, kwargs = get_fake_values_from_nodes(
         tx, (node.args, node.kwargs), allow_non_graph_fake
     )
-    flat_args_kwargs = get_fake_values_from_nodes(
-        tx, _get_flat_args(node, {}), allow_non_graph_fake
-    )
-    id_to_initial_version = {
-        id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
-    }
+
+    if (
+        torch._dynamo.config.use_graph_deduplication
+        or torch._dynamo.config.track_nodes_for_deduplication
+    ):
+        flat_args_kwargs = get_fake_values_from_nodes(
+            tx, _get_flat_args(node, {}), allow_non_graph_fake
+        )
+        id_to_initial_version = {
+            id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
+        }
+    else:
+        flat_args_kwargs = []
+        id_to_initial_version = {}

     nnmodule = None
     if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
