Skip to content

Commit 3dda804

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
fix is_inplace_node check
Summary: Oops it could be an edge op which failed the opOverload check. Switch to just hasattr Differential Revision: D77462717
1 parent ab4217e commit 3dda804

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,31 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
7676
"""Check if a node is an inplace node."""
7777
return (
7878
node.op == "call_function"
79-
and isinstance(node.target, torch._ops.OpOverload)
79+
and hasattr(node.target, "_schema")
8080
and is_inplace_variant(
81-
node.target._schema.name, node.target._schema.overload_name
81+
node.target._schema.name, node.target._schema.overload_name #pyre-ignore
8282
)
8383
)
8484

8585

8686
def _inplace_lineage(
8787
output_arg: torch.fx.Node,
88-
gm: torch.fx.GraphModule,
8988
gs: ExportGraphSignature,
9089
kind: SchemaKind,
9190
) -> bool:
9291
"""
9392
Walk the graph backwards to see if output_arg is ultimately the same as an input.
9493
"""
9594
if kind != OutputKind.BUFFER_MUTATION and kind != OutputKind.USER_INPUT_MUTATION:
95+
print("Wrong kind")
9696
return False
9797

9898
while output_arg.op != "placeholder":
9999
if _is_inplace_node(output_arg):
100100
# From looking at native_functions.yaml, inplace ops always have self as the first arg
101101
output_arg = output_arg.args[0] # pyre-ignore
102102
else:
103+
print("Not inplace " , output_arg)
103104
return False
104105

105106
# If the output arg was a buffer then it needs to reach a buffer placeholder
@@ -160,7 +161,6 @@ def insert_write_back_for_buffers_pass(
160161
# if the arg and target are not the same, we add a copy_ node.
161162
not _inplace_lineage(
162163
output_node.args[0][i],
163-
gm,
164164
ep.graph_signature,
165165
ep.graph_signature.output_specs[i].kind,
166166
)

exir/tests/test_reinplace_pass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def forward(
6161

6262
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
6363

64+
# Find the copy_ node
65+
copy_node = None
66+
for node in et.exported_program().graph.nodes:
67+
if node.op == "call_function" and "copy_" in str(node.target):
68+
copy_node = node
69+
break
70+
71+
self.assertIsNone(copy_node, "Shoulnt find an copy_ node")
72+
6473
e = _load_for_executorch_from_buffer(et.buffer)
6574
self.assertTrue(
6675
torch.allclose(

0 commit comments

Comments
 (0)