Skip to content

Commit 62bced0

Browse files
committed
Register _clone_dim_order and add operator tests
1 parent f8d18ec commit 62bced0

File tree

2 files changed

+48
-28
lines changed

exir/passes/dim_order_ops_registry.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@
2828
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
2929
)
3030

31+
lib.define(
32+
"_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
33+
)
34+
35+
lib.define(
36+
"_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
37+
)
38+
3139

3240
def _op_impl(target, *args, **kwargs):
3341
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,6 +65,16 @@ def _empty_dim_order_out_impl(*args, **kwargs):
5765
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)
5866

5967

68+
@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
69+
def _clone_dim_order_impl(*args, **kwargs):
70+
return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)
71+
72+
73+
@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
74+
def _clone_dim_order_out_impl(*args, **kwargs):
75+
return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)
76+
77+
6078
"""
6179
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
6280
"""

exir/tests/test_memory_format_ops_pass.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
MemoryFormatOpsPassTestUtils,
2929
MemoryFormatTestSet,
3030
PropagateToCopyChannalsLastModule,
31-
SimpleCloneChannelsLastModule,
32-
SimpleCloneContiguousModule,
3331
SimpleEmptyChannelLastModule,
3432
SimpleEmptyContiguoustModule,
3533
SimpleToCopyChannelsLastModule,
@@ -93,36 +91,40 @@ def test_op_empty_replacement_contiguous(self) -> None:
9391
),
9492
)
9593

96-
def test_op_clone_replacement_contiguous(self) -> None:
97-
model = SimpleCloneContiguousModule()
98-
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
99-
self,
100-
MemoryFormatTestSet(
101-
module=model.eval(),
102-
op=torch.ops.aten.clone.default,
103-
sample_input=(
104-
torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
105-
),
106-
target_memory_format=torch.contiguous_format,
107-
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
108-
),
94+
def test_op_clone_dim_order_preserves_channels_last(self):
95+
x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
96+
y = torch.ops.dim_order_ops._clone_dim_order.default(x)
97+
98+
assert y.is_contiguous(
99+
memory_format=torch.channels_last
100+
), "_clone_dim_order output is not in channels_last memory format."
101+
assert torch.allclose(x, y)
102+
103+
def test_op_clone_dim_order_to_contiguous(self):
104+
x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
105+
contiguous_dim_order = get_dim_order(torch.contiguous_format, x.dim())
106+
y = torch.ops.dim_order_ops._clone_dim_order.default(
107+
x, dim_order=contiguous_dim_order
109108
)
110109

111-
def test_op_clone_replacement_channels_last(self) -> None:
112-
model = SimpleCloneChannelsLastModule()
113-
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
114-
self,
115-
MemoryFormatTestSet(
116-
module=model.eval(),
117-
op=torch.ops.aten.clone.default,
118-
sample_input=(
119-
torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
120-
),
121-
target_memory_format=torch.channels_last,
122-
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
123-
),
110+
assert (
111+
y.is_contiguous()
112+
), "_clone_dim_order output is not in contiguous memory format"
113+
assert torch.allclose(x, y)
114+
115+
def test_op_clone_dim_order_out_to_channels_last(self):
116+
x = torch.randn(2, 3, 4, 5).contiguous()
117+
y = torch.empty_like(x, memory_format=torch.channels_last)
118+
channels_last_dim_order = get_dim_order(torch.channels_last, y.dim())
119+
torch.ops.dim_order_ops._clone_dim_order.out(
120+
x, dim_order=channels_last_dim_order, out=y
124121
)
125122

123+
assert y.is_contiguous(
124+
memory_format=torch.channels_last
125+
), "_clone_dim_order output is not in channels_last memory format"
126+
assert torch.allclose(x, y)
127+
126128
def test_op_dim_order_update(self) -> None:
127129
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
128130
self,

0 commit comments

Comments
 (0)