Skip to content

Commit a909b83

Browse files
fix is_inplace_node check
Differential Revision: D77462717 Pull Request resolved: #12071
1 parent dc7fd75 commit a909b83

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,15 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
7272
"""Check if a node is an inplace node."""
7373
return (
7474
node.op == "call_function"
75-
and isinstance(node.target, torch._ops.OpOverload)
75+
and hasattr(node.target, "_schema")
7676
and is_inplace_variant(
77-
node.target._schema.name, node.target._schema.overload_name
77+
node.target._schema.name, node.target._schema.overload_name # pyre-ignore
7878
)
7979
)
8080

8181

8282
def _inplace_lineage(
8383
output_arg: torch.fx.Node,
84-
gm: torch.fx.GraphModule,
8584
gs: ExportGraphSignature,
8685
kind: SchemaKind,
8786
) -> bool:
@@ -152,7 +151,6 @@ def insert_write_back_for_buffers_pass(
152151
# if the arg and target are not the same, we add a copy_ node.
153152
not _inplace_lineage(
154153
output_node.args[0][i],
155-
gm,
156154
ep.graph_signature,
157155
ep.graph_signature.output_specs[i].kind,
158156
)

exir/tests/test_reinplace_pass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def forward(
6161

6262
self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
6363

64+
# Find the copy_ node
65+
copy_node = None
66+
for node in et.exported_program().graph.nodes:
67+
if node.op == "call_function" and "copy_" in str(node.target):
68+
copy_node = node
69+
break
70+
71+
self.assertIsNone(copy_node, "Shouldn't find an copy_ node")
72+
6473
e = _load_for_executorch_from_buffer(et.buffer)
6574
self.assertTrue(
6675
torch.allclose(

0 commit comments

Comments
 (0)