@@ -461,12 +461,6 @@ def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
461461 )
462462
463463 def test_input_mutation (self ):
464- def inner_fn (x , y ):
465- x0 = x + 1
466- y0 = y + 2
467- z = x0 .sum () + y0 .sum ()
468- return z
469-
470464 def inner_fn2 (x , y ):
471465 x0 = x + 1
472466 y0 = y + 1
@@ -476,9 +470,6 @@ def inner_fn2(x, y):
476470
477471 def fn (x , y ):
478472 x0 = torch .sin (x )
479- _y0 = torch .cos (y )
480- # o0 = inner_fn(x0, y0)
481- # o1 = inner_fn(x0, o0)
482473 o2 = inner_fn2 (x0 , y )
483474 o3 = inner_fn2 (x0 .clone (), y .clone ())
484475 return o2 + o3
@@ -985,10 +976,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
985976 )
986977
987978 def test_mutation_ordering (self ):
988- from torch ._dynamo .graph_deduplication import (
989- _populate_additional_deps ,
990- _stable_topological_sort ,
991- )
979+ from torch ._dynamo .graph_deduplication import _stable_topological_sort
992980
993981 def inner_fn (x , y ):
994982 x0 = x .view (x .size ())
@@ -1013,74 +1001,109 @@ def fn(x, y):
10131001 x_clone = x .clone ()
10141002 y_clone = y .clone ()
10151003
1016- graph , tracker = extract_graph_and_tracker (fn , x_clone , y_clone )
1004+ graph , _ = extract_graph_and_tracker (fn , x_clone , y_clone )
1005+
1006+ def graph_code (graph ):
1007+ return graph .python_code ("self" ).src
10171008
10181009 def get_node (name ):
10191010 return next (n for n in graph .nodes if n .name == name )
10201011
1021- additional_deps = _populate_additional_deps (
1022- graph , tracker .node_to_mutated_arg_positions
1023- )
1024-
10251012 self .assertExpectedInline (
1026- additional_deps ,
1027- """defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""" ,
1013+ graph_code (graph ),
1014+ """\
1015+
1016+
1017+
1018+ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1019+ subgraph_0 = self.subgraph_0
1020+ l_x_ = L_x_
1021+ l_y_ = L_y_
1022+ x0 = l_x_.view((10, 10))
1023+ o0 = x0.view((10, 10)); x0 = None
1024+ x0_1 = l_x_.view((10, 10))
1025+ o1 = x0_1.view((10, 10)); x0_1 = None
1026+ add_ = l_x_.add_(l_x_); add_ = None
1027+ add_2 = o0 + o1; o0 = o1 = None
1028+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1029+ mul_ = l_y_.mul_(l_y_); mul_ = None
1030+ getitem = invoke_subgraph[0]; invoke_subgraph = None
1031+ sum_5 = getitem.sum(); getitem = None
1032+ add_3 = add_2 + sum_5; add_2 = sum_5 = None
1033+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1034+ getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1035+ sum_6 = getitem_1.sum(); getitem_1 = None
1036+ add_4 = add_3 + sum_6; add_3 = sum_6 = None
1037+ return (add_4,)
1038+ """ ,
10281039 )
10291040
1041+ # Shuffle nodes in the graph
10301042 add_ = get_node ("add_" )
10311043 mul_ = get_node ("mul_" )
1032- x0 = get_node ("x0" )
1033- x0 .append (mul_ )
10341044 o1 = get_node ("o1" )
1035- o1 .append (add_ )
1045+ o1 .append (mul_ )
1046+ add_2 = get_node ("add_2" )
1047+ add_2 .append (add_ )
1048+
10361049 self .assertExpectedInline (
1037- graph ,
1050+ graph_code ( graph ) ,
10381051 """\
1039- graph():
1040- %subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
1041- %l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
1042- %l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
1043- %x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1044- %mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
1045- %o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
1046- %x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1047- %o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
1048- %add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
1049- %add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
1050- %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1051- %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
1052- %sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
1053- %add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
1054- %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1055- %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
1056- %sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
1057- %add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
1058- return (add_4,)""" ,
1052+
1053+
1054+
1055+ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1056+ subgraph_0 = self.subgraph_0
1057+ l_x_ = L_x_
1058+ l_y_ = L_y_
1059+ x0 = l_x_.view((10, 10))
1060+ o0 = x0.view((10, 10)); x0 = None
1061+ x0_1 = l_x_.view((10, 10))
1062+ o1 = x0_1.view((10, 10)); x0_1 = None
1063+ mul_ = l_y_.mul_(l_y_); mul_ = None
1064+ add_2 = o0 + o1; o0 = o1 = None
1065+ add_ = l_x_.add_(l_x_); add_ = None
1066+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1067+ getitem = invoke_subgraph[0]; invoke_subgraph = None
1068+ sum_5 = getitem.sum(); getitem = None
1069+ add_3 = add_2 + sum_5; add_2 = sum_5 = None
1070+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1071+ getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1072+ sum_6 = getitem_1.sum(); getitem_1 = None
1073+ add_4 = add_3 + sum_6; add_3 = sum_6 = None
1074+ return (add_4,)
1075+ """ ,
1076+ )
1077+ _stable_topological_sort (
1078+ graph , torch ._dynamo .graph_deduplication .last_node_to_additional_deps
10591079 )
1060- _stable_topological_sort (graph , additional_deps )
10611080 self .assertExpectedInline (
1062- graph ,
1081+ graph_code ( graph ) ,
10631082 """\
1064- graph():
1065- %subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
1066- %l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
1067- %l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
1068- %x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1069- %o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
1070- %x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1071- %o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
1072- %add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
1073- %add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
1074- %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1075- %mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
1076- %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
1077- %sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
1078- %add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
1079- %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1080- %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
1081- %sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
1082- %add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
1083- return (add_4,)""" ,
1083+
1084+
1085+
1086+ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1087+ subgraph_0 = self.subgraph_0
1088+ l_x_ = L_x_
1089+ l_y_ = L_y_
1090+ x0 = l_x_.view((10, 10))
1091+ o0 = x0.view((10, 10)); x0 = None
1092+ x0_1 = l_x_.view((10, 10))
1093+ o1 = x0_1.view((10, 10)); x0_1 = None
1094+ add_2 = o0 + o1; o0 = o1 = None
1095+ add_ = l_x_.add_(l_x_); add_ = None
1096+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1097+ mul_ = l_y_.mul_(l_y_); mul_ = None
1098+ getitem = invoke_subgraph[0]; invoke_subgraph = None
1099+ sum_5 = getitem.sum(); getitem = None
1100+ add_3 = add_2 + sum_5; add_2 = sum_5 = None
1101+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1102+ getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1103+ sum_6 = getitem_1.sum(); getitem_1 = None
1104+ add_4 = add_3 + sum_6; add_3 = sum_6 = None
1105+ return (add_4,)
1106+ """ ,
10841107 )
10851108
10861109
0 commit comments