File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -76,16 +76,15 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
7676 """Check if a node is an inplace node."""
7777 return (
7878 node .op == "call_function"
79- and isinstance (node .target , torch . _ops . OpOverload )
79+ and hasattr (node .target , "_schema" )
8080 and is_inplace_variant (
81- node .target ._schema .name , node .target ._schema .overload_name
81+ node .target ._schema .name , node .target ._schema .overload_name #pyre-ignore
8282 )
8383 )
8484
8585
8686def _inplace_lineage (
8787 output_arg : torch .fx .Node ,
88- gm : torch .fx .GraphModule ,
8988 gs : ExportGraphSignature ,
9089 kind : SchemaKind ,
9190) -> bool :
@@ -160,7 +159,6 @@ def insert_write_back_for_buffers_pass(
160159 # if the arg and target are not the same, we add a copy_ node.
161160 not _inplace_lineage (
162161 output_node .args [0 ][i ],
163- gm ,
164162 ep .graph_signature ,
165163 ep .graph_signature .output_specs [i ].kind ,
166164 )
Original file line number Diff line number Diff line change @@ -61,6 +61,15 @@ def forward(
6161
6262 self .assertIsNotNone (index_put_node , "Should find an index_put_ node" )
6363
64+ # Find the copy_ node
65+ copy_node = None
66+ for node in et .exported_program ().graph .nodes :
67+ if node .op == "call_function" and "copy_" in str (node .target ):
68+ copy_node = node
69+ break
70+
71+ self .assertIsNone (copy_node , "Shouldn't find an copy_ node" )
72+
6473 e = _load_for_executorch_from_buffer (et .buffer )
6574 self .assertTrue (
6675 torch .allclose (
You can’t perform that action at this time.
0 commit comments