File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -72,16 +72,15 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
7272 """Check if a node is an inplace node."""
7373 return (
7474 node .op == "call_function"
75- and isinstance (node .target , torch . _ops . OpOverload )
75+ and hasattr (node .target , "_schema" )
7676 and is_inplace_variant (
77- node .target ._schema .name , node .target ._schema .overload_name
77+ node .target ._schema .name , node .target ._schema .overload_name # pyre-ignore
7878 )
7979 )
8080
8181
8282def _inplace_lineage (
8383 output_arg : torch .fx .Node ,
84- gm : torch .fx .GraphModule ,
8584 gs : ExportGraphSignature ,
8685 kind : SchemaKind ,
8786) -> bool :
@@ -152,7 +151,6 @@ def insert_write_back_for_buffers_pass(
152151 # if the arg and target are not the same, we add a copy_ node.
153152 not _inplace_lineage (
154153 output_node .args [0 ][i ],
155- gm ,
156154 ep .graph_signature ,
157155 ep .graph_signature .output_specs [i ].kind ,
158156 )
Original file line number Diff line number Diff line change @@ -61,6 +61,15 @@ def forward(
6161
6262 self .assertIsNotNone (index_put_node , "Should find an index_put_ node" )
6363
64+ # Find the copy_ node
65+ copy_node = None
66+ for node in et .exported_program ().graph .nodes :
67+ if node .op == "call_function" and "copy_" in str (node .target ):
68+ copy_node = node
69+ break
70+
71+ self .assertIsNone (copy_node , "Shouldn't find an copy_ node" )
72+
6473 e = _load_for_executorch_from_buffer (et .buffer )
6574 self .assertTrue (
6675 torch .allclose (
You can’t perform that action at this time.
0 commit comments