
Commit ffd1549

Add aten.clone memory_format check in RemoveCloneOpsTransform

1 parent ad74bdf
2 files changed: +82 −49 lines

backends/transforms/remove_clone_ops.py
Lines changed: 29 additions & 12 deletions

@@ -15,29 +15,46 @@ def remove_clone_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
     """
     Remove clone op nodes that have the same dim_order as their input, and replace their uses with the input node.
     """
-    clone_op = exir_ops.edge.aten.clone.default
-    clone_dim_order_op = exir_ops.edge.dim_order_ops._clone_dim_order.default
-
     for node in graph.nodes:
         if node.op != "call_function":
             continue
 
-        # Identify clone_dim_order ops with unchanged memory layout.
-        unchanged_layout_clone = (
-            node.target == clone_dim_order_op
-            and "val" in node.meta
-            and "val" in node.args[0].meta
-            and node.meta["val"].dim_order() == node.args[0].meta["val"].dim_order()
-        )
-
-        if node.target == clone_op or unchanged_layout_clone:
+        if is_unchanged_clone(node) or is_unchanged_dim_order_clone(node):
            with graph.inserting_after(node):
                node.replace_all_uses_with(node.args[0])
 
     graph.eliminate_dead_code()
     return graph
 
 
+def is_unchanged_clone(node: torch.fx.Node) -> bool:
+    """Determine if aten.clone has unchanged memory format."""
+    if node.target != exir_ops.edge.aten.clone.default:
+        return False
+
+    memory_format = node.kwargs.get("memory_format")
+    if memory_format in (None, torch.preserve_format):
+        return True
+
+    input_meta = node.args[0].meta
+    return "val" in input_meta and input_meta["val"].is_contiguous(
+        memory_format=memory_format
+    )
+
+
+def is_unchanged_dim_order_clone(node: torch.fx.Node) -> bool:
+    """Determine if _clone_dim_order has unchanged dim order."""
+    if node.target != exir_ops.edge.dim_order_ops._clone_dim_order.default:
+        return False
+
+    input_meta = node.args[0].meta
+    return (
+        "val" in node.meta
+        and "val" in input_meta
+        and node.meta["val"].dim_order() == input_meta["val"].dim_order()
+    )
+
+
 class RemoveCloneOpsTransform(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.graph = remove_clone_ops(graph_module.graph)
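For illustration, here is a minimal runnable sketch of the memory_format check that is_unchanged_clone performs, written in plain PyTorch outside the EXIR node machinery; the helper name clone_is_noop is hypothetical, not part of the commit, and the real pass reads the same layout information from the node's "val" FakeTensor metadata rather than from a live tensor.

import torch

def clone_is_noop(t: torch.Tensor, memory_format=None) -> bool:
    # None or torch.preserve_format keeps the input's layout, so the
    # clone never changes memory format.
    if memory_format in (None, torch.preserve_format):
        return True
    # An explicit memory_format is still a no-op when the input is
    # already contiguous in that format.
    return t.is_contiguous(memory_format=memory_format)

x = torch.randn(3, 4, 5, 6)                       # default contiguous layout
assert clone_is_noop(x)                           # preserve_format: no-op
assert clone_is_noop(x, torch.contiguous_format)  # already contiguous: no-op
assert not clone_is_noop(x, torch.channels_last)  # layout would change

y = x.to(memory_format=torch.channels_last)
assert clone_is_noop(y, torch.channels_last)      # already channels_last: no-op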

exir/tests/test_memory_format_ops_pass.py
Lines changed: 53 additions & 37 deletions

@@ -326,56 +326,72 @@ def call_operator(self, op, args, kwargs, meta):
         self.assertTrue(is_contiguous_dim_order(expected))
 
     def test_op_clone_replacement_channels_last_survives(self):
-        _clone_dim_order_op_str = (
-            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
-        )
+        clone_op_cases = [
+            # Case testing aten.clone by setting _skip_dim_order to True
+            (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+            # Case testing _clone_dim_order by setting _skip_dim_order to False
+            (
+                False,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+            ),
+        ]
 
-        model = SimpleCloneChannelsLastModule()
-        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+        for skip_dim_order, clone_op_str in clone_op_cases:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
 
-        exported = export(model.eval(), (x,), strict=True)
-        before_epm = to_edge(
-            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
-        )
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
 
-        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
 
-        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
-            updated_epm.exported_program().graph_module.code
-        )
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                updated_epm.exported_program().graph_module.code
+            )
 
-        expected = before_epm.exported_program().module()(x)
-        actual = updated_epm.exported_program().module()(x)
-        assert torch.allclose(actual, expected)
-        assert is_channel_last_dim_order(actual)
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
 
     def test_op_clone_without_transformation_removed(self):
-        _clone_dim_order_op_str = (
-            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
-        )
+        clone_op_cases = [
+            # Case testing aten.clone by setting _skip_dim_order to True
+            (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+            # Case testing _clone_dim_order by setting _skip_dim_order to False
+            (
+                False,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+            ),
+        ]
 
-        model = SimpleCloneChannelsLastModule()
-        x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+        for skip_dim_order, clone_op_str in clone_op_cases:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
 
-        exported = export(model.eval(), (x,), strict=True)
-        before_epm = to_edge(
-            exported, compile_config=EdgeCompileConfig(_skip_dim_order=False)
-        )
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
 
-        FileCheck().check_count(_clone_dim_order_op_str, 1, exactly=True).run(
-            before_epm.exported_program().graph_module.code
-        )
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                before_epm.exported_program().graph_module.code
+            )
 
-        updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
 
-        FileCheck().check_not(_clone_dim_order_op_str).run(
-            updated_epm.exported_program().graph_module.code
-        )
+            FileCheck().check_not(clone_op_str).run(
+                updated_epm.exported_program().graph_module.code
+            )
 
-        expected = before_epm.exported_program().module()(x)
-        actual = updated_epm.exported_program().module()(x)
-        assert torch.allclose(actual, expected)
-        assert is_channel_last_dim_order(actual)
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
 
     def test_resnet18(self) -> None:
         model = torchvision.models.resnet18()
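As background for the channels_last assertions in these tests, here is a small runnable example of how dim_order() exposes the layout being checked; this is plain PyTorch with no test harness, and it assumes a PyTorch version that provides Tensor.dim_order.

import torch

x = torch.randn(3, 4, 5, 6)                  # contiguous NCHW layout
y = x.to(memory_format=torch.channels_last)  # same logical values, NHWC layout

print(x.dim_order())  # (0, 1, 2, 3): contiguous dim order
print(y.dim_order())  # (0, 2, 3, 1): channels_last dim order

# A preserve_format clone keeps the dim order unchanged, which is exactly
# the case is_unchanged_dim_order_clone treats as removable.
z = y.clone(memory_format=torch.preserve_format)
assert z.dim_order() == y.dim_order()
assert torch.allclose(z, y)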
