Skip to content

Commit 1181920

Browse files
mlazos authored and pytorchmergebot committed
[Hierarchical Compile] Add mutation dependencies to topological sorting (pytorch#152410)
Pull Request resolved: pytorch#152410 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#152389, pytorch#152505
1 parent 3592cb5 commit 1181920

File tree

2 files changed

+173
-23
lines changed

2 files changed

+173
-23
lines changed

test/dynamo/test_graph_deduplication.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# Owner(s): ["module: dynamo"]
22
# flake8: noqa: B950
3+
import contextlib
4+
35
import torch
46
import torch.fx
57
from torch._dynamo.graph_utils import _detect_cycles
68
from torch._dynamo.test_case import TestCase
7-
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
9+
from torch._dynamo.testing import (
10+
AotEagerAndRecordGraphs,
11+
extract_graph_and_tracker,
12+
normalize_gm,
13+
)
814

915

1016
def extract_graph(fn, *args, **kwargs):
@@ -18,9 +24,19 @@ def graph_str(gm):
1824

1925

2026
class GraphDededuplicationTests(TestCase):
27+
def setUp(self):
28+
self.exit_stack = contextlib.ExitStack()
29+
self.exit_stack.enter_context(
30+
torch._dynamo.config.patch("use_graph_deduplication", True)
31+
)
32+
super().setUp()
33+
34+
def tearDown(self):
35+
self.exit_stack.close()
36+
super().tearDown()
37+
2138
def run_and_return_graphs(self, fn, *args, **kwargs):
22-
with torch._dynamo.config.patch("use_graph_deduplication", True):
23-
return extract_graph(fn, *args, **kwargs)
39+
return extract_graph(fn, *args, **kwargs)
2440

2541
def test_single_subgraph(self):
2642
def inner_fn(x, y):
@@ -691,7 +707,7 @@ def get_node(name):
691707
sum_2 = get_node("sum_2")
692708
exit_autocast = mod.graph.call_function(torch.amp._exit_autocast)
693709
sum_2.append(exit_autocast)
694-
additional_deps = _populate_additional_deps(mod.graph)
710+
additional_deps = _populate_additional_deps(mod.graph, {})
695711
invoke_subgraph = get_node("invoke_subgraph")
696712
invoke_subgraph.append(enter_autocast)
697713
getitem_1 = get_node("getitem_1")
@@ -906,6 +922,105 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
906922
""",
907923
)
908924

925+
def test_mutation_ordering(self):
926+
from torch._dynamo.graph_deduplication import (
927+
_populate_additional_deps,
928+
_stable_topological_sort,
929+
)
930+
931+
def inner_fn(x, y):
932+
x0 = x.view(x.size())
933+
return x0.view(x.size())
934+
935+
def inner_fn2(x, y):
936+
x = x * 2
937+
y = y * 2
938+
return x.sum() + y.sum()
939+
940+
def fn(x, y):
941+
o0 = inner_fn(x, y)
942+
o1 = inner_fn(x, y)
943+
x.add_(x)
944+
o2 = inner_fn2(x, y)
945+
y.mul_(y)
946+
o3 = inner_fn2(x, y)
947+
return o0 + o1 + o2.sum() + o3.sum()
948+
949+
x = torch.rand(10, 10)
950+
y = torch.rand(10, 20)
951+
x_clone = x.clone()
952+
y_clone = y.clone()
953+
954+
graph, tracker = extract_graph_and_tracker(fn, x_clone, y_clone)
955+
956+
def get_node(name):
957+
return next(n for n in graph.nodes if n.name == name)
958+
959+
additional_deps = _populate_additional_deps(
960+
graph, tracker.node_to_mutated_arg_positions
961+
)
962+
963+
self.assertExpectedInline(
964+
additional_deps,
965+
"""defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""",
966+
)
967+
968+
add_ = get_node("add_")
969+
mul_ = get_node("mul_")
970+
x0 = get_node("x0")
971+
x0.append(mul_)
972+
o1 = get_node("o1")
973+
o1.append(add_)
974+
self.assertExpectedInline(
975+
graph,
976+
"""\
977+
graph():
978+
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
979+
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
980+
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
981+
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
982+
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
983+
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
984+
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
985+
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
986+
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
987+
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
988+
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
989+
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
990+
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
991+
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
992+
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
993+
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
994+
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
995+
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
996+
return (add_4,)""",
997+
)
998+
_stable_topological_sort(graph, additional_deps)
999+
self.assertExpectedInline(
1000+
graph,
1001+
"""\
1002+
graph():
1003+
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
1004+
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
1005+
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
1006+
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1007+
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
1008+
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1009+
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
1010+
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
1011+
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
1012+
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1013+
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
1014+
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
1015+
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
1016+
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
1017+
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1018+
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
1019+
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
1020+
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
1021+
return (add_4,)""",
1022+
)
1023+
9091024

9101025
if __name__ == "__main__":
9111026
from torch._dynamo.test_case import run_tests

torch/_dynamo/graph_deduplication.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]:
5757
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
5858
output_graph.graph
5959
)
60-
node_to_additional_deps = _populate_additional_deps(output_graph.graph)
60+
node_to_additional_deps = _populate_additional_deps(
61+
output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
62+
)
6163

6264
sub_gms: dict[str, torch.fx.GraphModule] = {}
6365

@@ -107,7 +109,7 @@ def _replace_region_with_subgraph(
107109
inds_with_external_users: list[int],
108110
sub_gm: torch.fx.GraphModule,
109111
subgraph_name: str,
110-
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
112+
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
111113
) -> None:
112114
sub_args = []
113115
for node_ind, arg_ind in node_ind_arg_ind:
@@ -143,11 +145,12 @@ def _replace_region_with_subgraph(
143145
# Erase in reverse topological order
144146
for node in reversed(region):
145147
graph.erase_node(node)
146-
node_to_additional_deps.pop(node)
147-
for dep_list in node_to_additional_deps.values():
148+
node_to_additional_deps.pop(node, None)
149+
for deps in node_to_additional_deps.values():
148150
try:
149-
dep_list.remove(node)
150-
except ValueError:
151+
deps.remove(node)
152+
deps.add(invoke_subgraph_node)
153+
except KeyError:
151154
pass
152155

153156
if config.graph_deduplication_lint:
@@ -294,23 +297,29 @@ def _stable_topological_sort(
294297

295298

296299
def _populate_additional_deps(
297-
graph: torch.fx.Graph,
298-
) -> dict[torch.fx.Node, list[torch.fx.Node]]:
300+
graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
301+
) -> dict[Node, OrderedSet[Node]]:
302+
node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
303+
_add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps)
304+
_add_global_state_dependencies(graph, node_to_additional_deps)
305+
return node_to_additional_deps
306+
307+
308+
def _add_global_state_dependencies(
309+
graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
310+
) -> None:
299311
import torch.amp
300312

301-
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(
302-
list
303-
)
304313
all_nodes = list(graph.nodes)
305314

306315
# These are targets of the nodes which need to stay in the same relative place in the graph
307316
global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
308-
all_nodes_dep_on: list[torch.fx.Node] = []
317+
all_nodes_dep_on: list[Node] = []
309318

310319
def prev_cur_nodes(
311-
all_nodes: list[torch.fx.Node],
312-
) -> Generator[tuple[list[torch.fx.Node], torch.fx.Node]]:
313-
prev_nodes: list[torch.fx.Node] = []
320+
all_nodes: list[Node],
321+
) -> Generator[tuple[list[Node], Node], None, None]:
322+
prev_nodes: list[Node] = []
314323
next_nodes = list(reversed(all_nodes))
315324

316325
while next_nodes:
@@ -320,10 +329,36 @@ def prev_cur_nodes(
320329

321330
for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
322331
args_unique = _get_flat_args_unique(cur_node, {})
323-
additional_deps = node_to_additional_deps[cur_node]
324-
additional_deps.extend(n for n in all_nodes_dep_on if n not in args_unique)
332+
new_deps = [n for n in all_nodes_dep_on if n not in args_unique]
333+
334+
if new_deps:
335+
additional_deps = node_to_additional_deps[cur_node]
336+
additional_deps.update(new_deps)
337+
325338
if cur_node.target in global_state_targets:
326-
additional_deps.extend(n for n in prev_nodes if n not in args_unique)
339+
additional_deps = node_to_additional_deps[cur_node]
340+
additional_deps.update(n for n in prev_nodes if n not in args_unique)
327341
all_nodes_dep_on.append(cur_node)
328342

329-
return node_to_additional_deps
343+
344+
def _add_mutation_dependencies(
345+
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
346+
node_to_additional_deps: dict[Node, OrderedSet[Node]],
347+
) -> None:
348+
for node, indices in node_to_mutated_arg_positions.items():
349+
flat_args_kwargs = _get_flat_args(node, {})
350+
351+
# for all mutated args,
352+
# add dependency on usages which occur after node to ensure
353+
# node will always be ordered before them
354+
# also add node as a dependency on usages which
355+
# occur before node to ensure node is ordered after them
356+
for index in indices:
357+
mutated_arg = flat_args_kwargs[index]
358+
for user in mutated_arg.users:
359+
if user is node:
360+
continue
361+
elif user < node:
362+
node_to_additional_deps[node].add(user)
363+
elif user > node:
364+
node_to_additional_deps[user].add(node)

0 commit comments

Comments (0)