class MutableStateModule(torch.nn.Module):
    """Module with two mutable buffers, used to exercise buffer write-back.

    ``state`` is a running counter incremented in place on every call;
    ``direct_copy_from_input`` is overwritten in place with the raw input
    each call (a buffer whose new value is a graph input, not a computed
    tensor).
    """

    def __init__(self):
        super().__init__()
        # Counter buffer: bumped by 1 on each forward call.
        self.register_buffer("state", torch.zeros(1))
        # Snapshot buffer: holds the most recent input verbatim.
        self.register_buffer("direct_copy_from_input", torch.zeros(1))

    def forward(self, x):
        # Read the buffer BEFORE mutating it so the returned value
        # reflects the pre-increment state.
        out = x + self.state
        self.state.add_(1)
        self.direct_copy_from_input.copy_(x)
        return out
12971299
12981300 model = to_edge (export (MutableStateModule (), (torch .zeros (1 ),), strict = True ))
12991301 self .assertEqual (count_copies (model .exported_program ().graph_module ), 0 )
13001302 # Before
13011303 # graph():
1302- # %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
1303- # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1304- # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1305- # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1306- # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1307- # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1308- # return (aten_add_tensor_1, aten_add_tensor)
1304+ # %b_state : [num_users=2] = placeholder[target=b_state]
1305+ # %b_direct_copy_from_input : [num_users=0] = placeholder[target=b_direct_copy_from_input]
1306+ # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
1307+ # %x : [num_users=2] = placeholder[target=x]
1308+ # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
1309+ # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
1310+ # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
1311+ # return (aten_add_tensor_1, x, aten_add_tensor)
13091312 gm , _ = insert_write_back_for_buffers_pass (model .exported_program ())
13101313
13111314 # After
13121315 # graph():
1313- # %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
1314- # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1315- # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1316- # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1317- # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1318- # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1319- # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
1320- # return (copy__default, aten_add_tensor)
1321- self .assertEqual (count_copies (gm ), 1 )
1316+ # %b_state : [num_users=3] = placeholder[target=b_state]
1317+ # %b_direct_copy_from_input : [num_users=1] = placeholder[target=b_direct_copy_from_input]
1318+ # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
1319+ # %x : [num_users=2] = placeholder[target=x]
1320+ # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
1321+ # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
1322+ # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
1323+ # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_state, %aten_add_tensor_1), kwargs = {})
1324+ # %copy__default_1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_direct_copy_from_input, %x), kwargs = {})
1325+ # return (copy__default, copy__default_1, aten_add_tensor)
1326+ self .assertEqual (count_copies (gm ), 2 )
13221327
13231328 def test_remove_quantized_op_noop_pass (self ) -> None :
13241329 class TestAddSliceNoop (torch .nn .Module ):
0 commit comments