File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -72,16 +72,15 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
72
72
"""Check if a node is an inplace node."""
73
73
return (
74
74
node .op == "call_function"
75
- and isinstance (node .target , torch . _ops . OpOverload )
75
+ and hasattr (node .target , "_schema" )
76
76
and is_inplace_variant (
77
- node .target ._schema .name , node .target ._schema .overload_name
77
+ node .target ._schema .name , node .target ._schema .overload_name # pyre-ignore
78
78
)
79
79
)
80
80
81
81
82
82
def _inplace_lineage (
83
83
output_arg : torch .fx .Node ,
84
- gm : torch .fx .GraphModule ,
85
84
gs : ExportGraphSignature ,
86
85
kind : SchemaKind ,
87
86
) -> bool :
@@ -152,7 +151,6 @@ def insert_write_back_for_buffers_pass(
152
151
# if the arg and target are not the same, we add a copy_ node.
153
152
not _inplace_lineage (
154
153
output_node .args [0 ][i ],
155
- gm ,
156
154
ep .graph_signature ,
157
155
ep .graph_signature .output_specs [i ].kind ,
158
156
)
Original file line number Diff line number Diff line change @@ -61,6 +61,15 @@ def forward(
61
61
62
62
self .assertIsNotNone (index_put_node , "Should find an index_put_ node" )
63
63
64
+ # Find the copy_ node
65
+ copy_node = None
66
+ for node in et .exported_program ().graph .nodes :
67
+ if node .op == "call_function" and "copy_" in str (node .target ):
68
+ copy_node = node
69
+ break
70
+
71
+ self .assertIsNone (copy_node , "Shouldn't find an copy_ node" )
72
+
64
73
e = _load_for_executorch_from_buffer (et .buffer )
65
74
self .assertTrue (
66
75
torch .allclose (
You can’t perform that action at this time.
0 commit comments