Skip to content

Commit 877119f

Browse files
committed
Register clone_dim_order op; add test for op replacement
1 parent d1c87e4 commit 877119f

File tree

3 files changed

+54
-0
lines changed

3 files changed

+54
-0
lines changed

exir/passes/dim_order_ops_registry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@
2828
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
2929
)
3030

31+
lib.define(
32+
"_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
33+
)
34+
35+
lib.define(
36+
"_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
37+
)
38+
3139

3240
def _op_impl(target, *args, **kwargs):
3341
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs):
5765
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)
5866

5967

68+
@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
69+
def _clone_dim_order_impl(*args, **kwargs):
70+
return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)
71+
72+
73+
@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
74+
def _clone_dim_order_out_impl(*args, **kwargs):
75+
return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)
76+
77+
6078
"""
6179
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
6280
"""
6381
DimOrderOpsMap = {
6482
exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
6583
exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
84+
exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default,
6685
}
6786

6887
"""

exir/tests/test_memory_format_ops_pass.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
MemoryFormatOpsPassTestUtils,
2929
MemoryFormatTestSet,
3030
PropagateToCopyChannalsLastModule,
31+
SimpleCloneChannelsLastModule,
3132
SimpleEmptyChannelLastModule,
3233
SimpleEmptyContiguoustModule,
3334
SimpleToCopyChannelsLastModule,
@@ -389,3 +390,17 @@ def test_mobilenet_v3_xnnpack(self) -> None:
389390
rtol=1e-3,
390391
),
391392
)
393+
394+
def test_op_clone_dim_order_replacement(self):
395+
model = SimpleCloneChannelsLastModule()
396+
x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
397+
clone_dim_order_op_str = (
398+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
399+
)
400+
401+
exported = export(model.eval(), (x,), strict=True)
402+
epm = to_edge(exported, compile_config=EdgeCompileConfig(_skip_dim_order=False))
403+
404+
FileCheck().check_count(clone_dim_order_op_str, 1, exactly=True).run(
405+
epm.exported_program().graph_module.code
406+
)

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
"torch.ops.aten.empty.memory_format",
3939
"executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
4040
),
41+
torch.ops.aten.clone.default: (
42+
"torch.ops.aten.clone.default",
43+
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
44+
),
4145
}
4246

4347

@@ -70,6 +74,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7074
return x.to(dtype=torch.double, memory_format=torch.channels_last)
7175

7276

77+
class SimpleCloneContiguousModule(torch.nn.Module):
78+
def __init__(self):
79+
super().__init__()
80+
81+
def forward(self, x: torch.Tensor) -> torch.Tensor:
82+
return x.clone(memory_format=torch.contiguous_format)
83+
84+
85+
class SimpleCloneChannelsLastModule(torch.nn.Module):
86+
def __init__(self):
87+
super().__init__()
88+
89+
def forward(self, x: torch.Tensor) -> torch.Tensor:
90+
return x.clone(memory_format=torch.channels_last)
91+
92+
7393
class SimpleEmptyContiguoustModule(torch.nn.Module):
7494
def __init__(self):
7595
super().__init__()

0 commit comments

Comments
 (0)