Skip to content

Commit 0160063

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
fix is_inplace_node check (#12071)
Summary: Oops it could be an edge op which failed the opOverload check. Switch to just hasattr Reviewed By: larryliu0820 Differential Revision: D77462717
1 parent ab4217e commit 0160063

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,15 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
7676
"""Check if a node is an inplace node."""
7777
return (
7878
node.op == "call_function"
79-
and isinstance(node.target, torch._ops.OpOverload)
79+
and hasattr(node.target, "_schema")
8080
and is_inplace_variant(
81-
node.target._schema.name, node.target._schema.overload_name
81+
node.target._schema.name, node.target._schema.overload_name #pyre-ignore
8282
)
8383
)
8484

8585

8686
def _inplace_lineage(
8787
output_arg: torch.fx.Node,
88-
gm: torch.fx.GraphModule,
8988
gs: ExportGraphSignature,
9089
kind: SchemaKind,
9190
) -> bool:
@@ -160,7 +159,6 @@ def insert_write_back_for_buffers_pass(
160159
# if the arg and target are not the same, we add a copy_ node.
161160
not _inplace_lineage(
162161
output_node.args[0][i],
163-
gm,
164162
ep.graph_signature,
165163
ep.graph_signature.output_specs[i].kind,
166164
)

exir/tests/test_reinplace_pass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def forward(
6161

6262
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
6363

64+
# Find the copy_ node
65+
copy_node = None
66+
for node in et.exported_program().graph.nodes:
67+
if node.op == "call_function" and "copy_" in str(node.target):
68+
copy_node = node
69+
break
70+
71+
self.assertIsNone(copy_node, "Shouldn't find an copy_ node")
72+
6473
e = _load_for_executorch_from_buffer(et.buffer)
6574
self.assertTrue(
6675
torch.allclose(

0 commit comments

Comments
 (0)