Skip to content

Commit a415c98

Browse files
mlazospytorchmergebot
authored andcommitted
[Hierarchical Compile] Replace tracing alias and mutation check with dynamo impl (pytorch#152570)
Pull Request resolved: pytorch#152570 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#152389, pytorch#152505, pytorch#152410, pytorch#152506
1 parent 57dafb9 commit a415c98

File tree

2 files changed

+210
-107
lines changed

2 files changed

+210
-107
lines changed

test/dynamo/test_graph_deduplication.py

Lines changed: 89 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,6 @@ def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
461461
)
462462

463463
def test_input_mutation(self):
464-
def inner_fn(x, y):
465-
x0 = x + 1
466-
y0 = y + 2
467-
z = x0.sum() + y0.sum()
468-
return z
469-
470464
def inner_fn2(x, y):
471465
x0 = x + 1
472466
y0 = y + 1
@@ -476,9 +470,6 @@ def inner_fn2(x, y):
476470

477471
def fn(x, y):
478472
x0 = torch.sin(x)
479-
_y0 = torch.cos(y)
480-
# o0 = inner_fn(x0, y0)
481-
# o1 = inner_fn(x0, o0)
482473
o2 = inner_fn2(x0, y)
483474
o3 = inner_fn2(x0.clone(), y.clone())
484475
return o2 + o3
@@ -985,10 +976,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
985976
)
986977

987978
def test_mutation_ordering(self):
988-
from torch._dynamo.graph_deduplication import (
989-
_populate_additional_deps,
990-
_stable_topological_sort,
991-
)
979+
from torch._dynamo.graph_deduplication import _stable_topological_sort
992980

993981
def inner_fn(x, y):
994982
x0 = x.view(x.size())
@@ -1013,74 +1001,109 @@ def fn(x, y):
10131001
x_clone = x.clone()
10141002
y_clone = y.clone()
10151003

1016-
graph, tracker = extract_graph_and_tracker(fn, x_clone, y_clone)
1004+
graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)
1005+
1006+
def graph_code(graph):
1007+
return graph.python_code("self").src
10171008

10181009
def get_node(name):
10191010
return next(n for n in graph.nodes if n.name == name)
10201011

1021-
additional_deps = _populate_additional_deps(
1022-
graph, tracker.node_to_mutated_arg_positions
1023-
)
1024-
10251012
self.assertExpectedInline(
1026-
additional_deps,
1027-
"""defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""",
1013+
graph_code(graph),
1014+
"""\
1015+
1016+
1017+
1018+
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1019+
subgraph_0 = self.subgraph_0
1020+
l_x_ = L_x_
1021+
l_y_ = L_y_
1022+
x0 = l_x_.view((10, 10))
1023+
o0 = x0.view((10, 10)); x0 = None
1024+
x0_1 = l_x_.view((10, 10))
1025+
o1 = x0_1.view((10, 10)); x0_1 = None
1026+
add_ = l_x_.add_(l_x_); add_ = None
1027+
add_2 = o0 + o1; o0 = o1 = None
1028+
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1029+
mul_ = l_y_.mul_(l_y_); mul_ = None
1030+
getitem = invoke_subgraph[0]; invoke_subgraph = None
1031+
sum_5 = getitem.sum(); getitem = None
1032+
add_3 = add_2 + sum_5; add_2 = sum_5 = None
1033+
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1034+
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1035+
sum_6 = getitem_1.sum(); getitem_1 = None
1036+
add_4 = add_3 + sum_6; add_3 = sum_6 = None
1037+
return (add_4,)
1038+
""",
10281039
)
10291040

1041+
# Shuffle nodes in the graph
10301042
add_ = get_node("add_")
10311043
mul_ = get_node("mul_")
1032-
x0 = get_node("x0")
1033-
x0.append(mul_)
10341044
o1 = get_node("o1")
1035-
o1.append(add_)
1045+
o1.append(mul_)
1046+
add_2 = get_node("add_2")
1047+
add_2.append(add_)
1048+
10361049
self.assertExpectedInline(
1037-
graph,
1050+
graph_code(graph),
10381051
"""\
1039-
graph():
1040-
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
1041-
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
1042-
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
1043-
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1044-
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
1045-
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
1046-
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1047-
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
1048-
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
1049-
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
1050-
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1051-
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
1052-
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
1053-
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
1054-
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1055-
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
1056-
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
1057-
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
1058-
return (add_4,)""",
1052+
1053+
1054+
1055+
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1056+
subgraph_0 = self.subgraph_0
1057+
l_x_ = L_x_
1058+
l_y_ = L_y_
1059+
x0 = l_x_.view((10, 10))
1060+
o0 = x0.view((10, 10)); x0 = None
1061+
x0_1 = l_x_.view((10, 10))
1062+
o1 = x0_1.view((10, 10)); x0_1 = None
1063+
mul_ = l_y_.mul_(l_y_); mul_ = None
1064+
add_2 = o0 + o1; o0 = o1 = None
1065+
add_ = l_x_.add_(l_x_); add_ = None
1066+
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1067+
getitem = invoke_subgraph[0]; invoke_subgraph = None
1068+
sum_5 = getitem.sum(); getitem = None
1069+
add_3 = add_2 + sum_5; add_2 = sum_5 = None
1070+
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1071+
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1072+
sum_6 = getitem_1.sum(); getitem_1 = None
1073+
add_4 = add_3 + sum_6; add_3 = sum_6 = None
1074+
return (add_4,)
1075+
""",
1076+
)
1077+
_stable_topological_sort(
1078+
graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
10591079
)
1060-
_stable_topological_sort(graph, additional_deps)
10611080
self.assertExpectedInline(
1062-
graph,
1081+
graph_code(graph),
10631082
"""\
1064-
graph():
1065-
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
1066-
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
1067-
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
1068-
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1069-
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
1070-
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
1071-
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
1072-
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
1073-
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
1074-
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1075-
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
1076-
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
1077-
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
1078-
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
1079-
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
1080-
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
1081-
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
1082-
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
1083-
return (add_4,)""",
1083+
1084+
1085+
1086+
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
1087+
subgraph_0 = self.subgraph_0
1088+
l_x_ = L_x_
1089+
l_y_ = L_y_
1090+
x0 = l_x_.view((10, 10))
1091+
o0 = x0.view((10, 10)); x0 = None
1092+
x0_1 = l_x_.view((10, 10))
1093+
o1 = x0_1.view((10, 10)); x0_1 = None
1094+
add_2 = o0 + o1; o0 = o1 = None
1095+
add_ = l_x_.add_(l_x_); add_ = None
1096+
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
1097+
mul_ = l_y_.mul_(l_y_); mul_ = None
1098+
getitem = invoke_subgraph[0]; invoke_subgraph = None
1099+
sum_5 = getitem.sum(); getitem = None
1100+
add_3 = add_2 + sum_5; add_2 = sum_5 = None
1101+
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
1102+
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
1103+
sum_6 = getitem_1.sum(); getitem_1 = None
1104+
add_4 = add_3 + sum_6; add_3 = sum_6 = None
1105+
return (add_4,)
1106+
""",
10841107
)
10851108

10861109

0 commit comments

Comments
 (0)