Skip to content

Commit 83d4e52

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix test constant_prop_pass_for_add (#1920)
Summary: Pull Request resolved: #1920 Reviewed By: mergennachin Differential Revision: D53629151 fbshipit-source-id: f9b1862e8b45a9ca12ddcac916ca7f76930da974
1 parent 7128b3e commit 83d4e52

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

exir/tests/test_passes.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,26 +1110,28 @@ def forward(self, x):
11101110

11111111
def test_constant_prop_pass_for_add(self) -> None:
11121112
class Add(torch.nn.Module):
1113-
def add(self, x: torch.Tensor) -> torch.Tensor:
1113+
def forward(self, x: torch.Tensor) -> torch.Tensor:
11141114
return x + 3
11151115

11161116
add = Add()
11171117

11181118
edge = to_edge(export(add, (torch.ones(1),)))
11191119
edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()])
1120-
edge.exported_program = lift_constant_tensor_pass(edge.exported_program())
1120+
exported_program = lift_constant_tensor_pass(edge.exported_program())
11211121

11221122
# Check there is a lifted tensor followed by a to_copy node
11231123
FileCheck().check("_lifted_tensor_constant0").check(
1124-
"torch.ops.aten._to_copy.default"
1125-
).run(edge.exported_program().graph_module.code)
1124+
"executorch_exir_dialects_edge__ops_aten__to_copy_default"
1125+
).run(exported_program.graph_module.code)
11261126

1127-
new_ep = constant_prop_pass(edge.exported_program())
1127+
new_ep = constant_prop_pass(exported_program)
11281128

11291129
# Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
11301130
FileCheck().check_not("_lifted_tensor_constant").check(
11311131
"_prop_tensor_constant1"
1132-
).check_not("torch.ops.aten._to_copy.default").run(new_ep.graph_module.code)
1132+
).check_not("executorch_exir_dialects_edge__ops_aten__to_copy_default").run(
1133+
new_ep.graph_module.code
1134+
)
11331135

11341136
def test_constant_prop_pass_for_parameter(self) -> None:
11351137
def count_additions(gm: torch.fx.GraphModule) -> int:

0 commit comments

Comments
 (0)