11# Owner(s): ["module: dynamo"]
22# flake8: noqa: B950
3+ import contextlib
4+
35import torch
46import torch .fx
57from torch ._dynamo .graph_utils import _detect_cycles
68from torch ._dynamo .test_case import TestCase
7- from torch ._dynamo .testing import AotEagerAndRecordGraphs , normalize_gm
9+ from torch ._dynamo .testing import (
10+ AotEagerAndRecordGraphs ,
11+ extract_graph_and_tracker ,
12+ normalize_gm ,
13+ )
814
915
1016def extract_graph (fn , * args , ** kwargs ):
@@ -18,9 +24,19 @@ def graph_str(gm):
1824
1925
2026class GraphDededuplicationTests (TestCase ):
def setUp(self):
    """Enable ``use_graph_deduplication`` for the duration of each test.

    The config patch is held on ``self.exit_stack`` so that ``tearDown``
    can undo it once the test has finished.
    """
    self.exit_stack = contextlib.ExitStack()
    self.exit_stack.enter_context(
        torch._dynamo.config.patch("use_graph_deduplication", True)
    )
    try:
        super().setUp()
    except BaseException:
        # unittest does not call tearDown when setUp raises, so the patch
        # has to be undone here or it would leak into subsequent tests.
        self.exit_stack.close()
        raise
33+
def tearDown(self):
    """Undo the config patch installed in ``setUp``, then run base teardown."""
    try:
        self.exit_stack.close()
    finally:
        # Run the base-class teardown even if unwinding the stack raised,
        # so its cleanup is never skipped.
        super().tearDown()
37+
def run_and_return_graphs(self, fn, *args, **kwargs):
    """Forward ``fn`` and its arguments to ``extract_graph`` and return the result.

    No config override is applied here: ``use_graph_deduplication`` is
    already patched on for every test by ``setUp``.
    """
    graphs = extract_graph(fn, *args, **kwargs)
    return graphs
2440
2541 def test_single_subgraph (self ):
2642 def inner_fn (x , y ):
@@ -691,7 +707,7 @@ def get_node(name):
691707 sum_2 = get_node ("sum_2" )
692708 exit_autocast = mod .graph .call_function (torch .amp ._exit_autocast )
693709 sum_2 .append (exit_autocast )
694- additional_deps = _populate_additional_deps (mod .graph )
710+ additional_deps = _populate_additional_deps (mod .graph , {} )
695711 invoke_subgraph = get_node ("invoke_subgraph" )
696712 invoke_subgraph .append (enter_autocast )
697713 getitem_1 = get_node ("getitem_1" )
@@ -906,6 +922,105 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
906922""" ,
907923 )
908924
def test_mutation_ordering(self):
    """Check mutation-derived dependencies and the ordering they enforce.

    The test captures ``fn`` (which interleaves in-place ``add_``/``mul_``
    calls with repeated helper calls), asserts the dependency map returned
    by ``_populate_additional_deps``, moves the two in-place nodes to other
    positions in the graph, and finally asserts the node order produced by
    ``_stable_topological_sort`` with those dependencies.
    """
    from torch._dynamo.graph_deduplication import (
        _populate_additional_deps,
        _stable_topological_sort,
    )

    # NOTE: the local names inside the traced functions below show up as
    # node names in the expected strings (x0, o0, o1, ...), so they are
    # kept exactly as-is.
    def inner_fn(x, y):
        x0 = x.view(x.size())
        return x0.view(x.size())

    def inner_fn2(x, y):
        x = x * 2
        y = y * 2
        return x.sum() + y.sum()

    def fn(x, y):
        o0 = inner_fn(x, y)
        o1 = inner_fn(x, y)
        x.add_(x)
        o2 = inner_fn2(x, y)
        y.mul_(y)
        o3 = inner_fn2(x, y)
        return o0 + o1 + o2.sum() + o3.sum()

    x_input = torch.rand(10, 10).clone()
    y_input = torch.rand(10, 20).clone()

    graph, tracker = extract_graph_and_tracker(fn, x_input, y_input)

    def find_node(name):
        # First node in the graph carrying the requested name.
        return next(filter(lambda node: node.name == name, graph.nodes))

    mutated_positions = tracker.node_to_mutated_arg_positions
    additional_deps = _populate_additional_deps(graph, mutated_positions)

    self.assertExpectedInline(
        additional_deps,
        """defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""",
    )

    add_node = find_node("add_")
    mul_node = find_node("mul_")

    # Relocate the in-place nodes: mul_ goes right after x0 and add_ right
    # after o1, giving the sort below something to reorder.
    find_node("x0").append(mul_node)
    find_node("o1").append(add_node)

    self.assertExpectedInline(
        graph,
        """\
graph():
    %subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
    %l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
    %l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
    %x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
    %mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
    %o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
    %x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
    %o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
    %add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
    %sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
    %sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
    return (add_4,)""",
    )

    # Re-sort in place using the mutation dependencies computed above.
    _stable_topological_sort(graph, additional_deps)

    self.assertExpectedInline(
        graph,
        """\
graph():
    %subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
    %l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
    %l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
    %x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
    %o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
    %x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
    %o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
    %add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
    %mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
    %sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
    %sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
    return (add_4,)""",
    )
1023+
9091024
9101025if __name__ == "__main__" :
9111026 from torch ._dynamo .test_case import run_tests
0 commit comments