Skip to content

Commit 2cf34da

Browse files
Add unit test.
1 parent 693cc29 commit 2cf34da

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

exir/tests/test_passes.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,36 +1289,41 @@ class MutableStateModule(torch.nn.Module):
12891289
def __init__(self):
12901290
super().__init__()
12911291
self.register_buffer("state", torch.zeros(1))
1292+
self.register_buffer("direct_copy_from_input", torch.zeros(1))
12921293

12931294
def forward(self, x):
12941295
y = x + self.state
12951296
self.state.add_(1)
1297+
self.direct_copy_from_input.copy_(x)
12961298
return y
12971299

12981300
model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
12991301
self.assertEqual(count_copies(model.exported_program().graph_module), 0)
13001302
# Before
13011303
# graph():
1302-
# %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
1303-
# %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1304-
# %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1305-
# %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1306-
# %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1307-
# %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1308-
# return (aten_add_tensor_1, aten_add_tensor)
1304+
# %b_state : [num_users=2] = placeholder[target=b_state]
1305+
# %b_direct_copy_from_input : [num_users=0] = placeholder[target=b_direct_copy_from_input]
1306+
# %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
1307+
# %x : [num_users=2] = placeholder[target=x]
1308+
# %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
1309+
# %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
1310+
# %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
1311+
# return (aten_add_tensor_1, x, aten_add_tensor)
13091312
gm, _ = insert_write_back_for_buffers_pass(model.exported_program())
13101313

13111314
# After
13121315
# graph():
1313-
# %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
1314-
# %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
1315-
# %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
1316-
# %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
1317-
# %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
1318-
# %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
1319-
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
1320-
# return (copy__default, aten_add_tensor)
1321-
self.assertEqual(count_copies(gm), 1)
1316+
# %b_state : [num_users=3] = placeholder[target=b_state]
1317+
# %b_direct_copy_from_input : [num_users=1] = placeholder[target=b_direct_copy_from_input]
1318+
# %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
1319+
# %x : [num_users=2] = placeholder[target=x]
1320+
# %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
1321+
# %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
1322+
# %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
1323+
# %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_state, %aten_add_tensor_1), kwargs = {})
1324+
# %copy__default_1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_direct_copy_from_input, %x), kwargs = {})
1325+
# return (copy__default, copy__default_1, aten_add_tensor)
1326+
self.assertEqual(count_copies(gm), 2)
13221327

13231328
def test_remove_quantized_op_noop_pass(self) -> None:
13241329
class TestAddSliceNoop(torch.nn.Module):

0 commit comments

Comments (0)