
Commit bdb7f98

Merge branch 'main' into add-cast-int64-placeholders-to-int32-pass
2 parents: 1b203eb + dbac09c

File tree: 4 files changed, +106 -9 lines


backends/qualcomm/tests/test_qnn_delegate.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -4979,9 +4979,9 @@ def test_static_qwen2_5(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                inference_speed_ref = {"SM8650": 110, "SM8750": 130}
+                inference_speed_ref = {"SM8650": 115, "SM8750": 155}
                 self.assertLessEqual(msg["wiki_ppl"], 15)
-                self.assertLessEqual(msg["pte_size"], 800000000)  # 800mb
+                self.assertLessEqual(msg["pte_size"], 600000000)  # 600mb
                 if self.model in inference_speed_ref:
                     self.assertGreaterEqual(
                         msg["inference_speed"], inference_speed_ref[self.model]
```

backends/transforms/remove_clone_ops.py
Lines changed: 30 additions & 2 deletions

```diff
@@ -22,6 +22,7 @@ class RemoveCloneOpsTransform(ExportPass):
 
     clone_ops: Set[torch._ops.OpOverload] = {
         exir_ops.edge.aten.clone.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
     def __init__(self) -> None:
@@ -34,12 +35,15 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
             if n.target not in self.clone_ops:
                 continue
 
-            to_be_remove = n
+            if self._is_non_identity_clone(n):
+                continue
+
+            to_be_removed = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])
                 if n.args[0].target in _DEQUANT_OPS:
                     dequant_nodes += [n.args[0]]
-            graph_module.graph.erase_node(to_be_remove)
+            graph_module.graph.erase_node(to_be_removed)
 
         eliminate_dq_q(graph_module, dequant_nodes)
 
@@ -48,3 +52,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
         return PassResult(graph_module, True)
+
+    def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
+        """Return True if clone has modified memory layout or dim order."""
+
+        # aten.clone: check for memory_format changes
+        if node.target == exir_ops.edge.aten.clone.default:
+            memory_format = node.kwargs.get("memory_format")
+            if memory_format in (None, torch.preserve_format):
+                return False
+            input_meta = node.args[0].meta
+            return "val" in input_meta and not input_meta["val"].is_contiguous(
+                memory_format=memory_format
+            )
+
+        # _clone_dim_order: check for dim_order changes
+        if node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default:
+            input_meta = node.args[0].meta
+            return (
+                "val" in node.meta
+                and "val" in input_meta
+                and node.meta["val"].dim_order() != input_meta["val"].dim_order()
+            )
+
+        return False
```
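For context on what `_is_non_identity_clone` guards against, here is a minimal eager-PyTorch sketch (my own illustration, not code from this commit) of the difference between an identity clone and one that changes memory layout:

```python
# Minimal sketch: when is a clone an identity op? Plain PyTorch on a 4D NCHW
# tensor; the pass applies the same idea to edge-dialect clone nodes.
import torch

x = torch.randn(2, 3, 4, 5)  # contiguous NCHW input

# Identity clone: memory_format is preserved, so removing the node would not
# change what downstream ops observe.
same = x.clone(memory_format=torch.preserve_format)
assert same.is_contiguous()

# Non-identity clone: the output becomes channels_last, so erasing the node
# would silently change the layout seen downstream; the pass must keep it.
nhwc = x.clone(memory_format=torch.channels_last)
assert not nhwc.is_contiguous()
assert nhwc.is_contiguous(memory_format=torch.channels_last)
```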

backends/transforms/test/test_remove_clone_ops.py
Lines changed: 69 additions & 0 deletions

```diff
@@ -8,13 +8,30 @@
 
 import torch
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dim_order_utils import is_channel_last_dim_order
+from executorch.exir.tests.test_memory_format_ops_pass_utils import (
+    SimpleCloneChannelsLastModule,
+)
+from torch.export import export
 from torch.fx import GraphModule
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import TestCase
 
 
 class TestRemoveCloneOpsTransform(TestCase):
+    # Clone ops can appear as either aten.clone or _clone_dim_order depending on the _skip_dim_order flag.
+    # _skip_dim_order=True tests aten.clone
+    # _skip_dim_order=False tests _clone_dim_order
+    CLONE_OP_CASES = [
+        (True, "executorch_exir_dialects_edge__ops_aten_clone_default"),
+        (
+            False,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+        ),
+    ]
+
     def test_dq_clone_q_linear(self):
         """
         Test RemoveCloneOpsTransform on a graph with d/q -> clone -> q -> linear pattern
@@ -123,6 +140,58 @@ def forward(self, x):
             transformed_gm.code
         )
 
+    def test_clone_non_identity_survives(self):
+        """Verify clone ops that modify memory_format are preserved by RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.contiguous_format)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
+    def test_clone_identity_removed(self):
+        """Verify identity clone ops are removed by RemoveCloneOpsTransform."""
+
+        for skip_dim_order, clone_op_str in self.CLONE_OP_CASES:
+            model = SimpleCloneChannelsLastModule()
+            x = torch.randn(3, 4, 5, 6).to(memory_format=torch.channels_last)
+
+            exported = export(model.eval(), (x,), strict=True)
+            before_epm = to_edge(
+                exported,
+                compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
+            )
+
+            FileCheck().check_count(clone_op_str, 1, exactly=True).run(
+                before_epm.exported_program().graph_module.code
+            )
+
+            updated_epm = before_epm.transform([RemoveCloneOpsTransform()])
+
+            FileCheck().check_not(clone_op_str).run(
+                updated_epm.exported_program().graph_module.code
+            )
+
+            expected = before_epm.exported_program().module()(x)
+            actual = updated_epm.exported_program().module()(x)
+            assert torch.allclose(actual, expected)
+            assert is_channel_last_dim_order(actual)
+
 
 if __name__ == "__main__":
     unittest.main()
```
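The `is_channel_last_dim_order` assertions above rest on the dim-order view of tensor layouts. A small sketch of what `_clone_dim_order`'s check compares, assuming PyTorch's `Tensor.dim_order()` API (my own illustration, not code from this commit):

```python
# Minimal sketch of the dim-order comparison used by _is_non_identity_clone:
# a contiguous 4D tensor reports (0, 1, 2, 3), a channels_last one (0, 2, 3, 1).
import torch

x = torch.randn(3, 4, 5, 6)
assert x.dim_order() == (0, 1, 2, 3)  # contiguous layout

y = x.to(memory_format=torch.channels_last)
assert y.dim_order() == (0, 2, 3, 1)  # channels-last layout

# A clone whose input and output dim orders match is removable; one whose
# dim orders differ (like x -> y here) must survive the pass.
```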

examples/qualcomm/oss_scripts/llama/__init__.py
Lines changed: 5 additions & 5 deletions

```diff
@@ -211,8 +211,8 @@ class Qwen2_5_0_5B(LLMModelConfig):
 
     num_sharding = 1
     # quant config
-    ptq = QuantDtype.use_16a8w
-    group_size = None
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 16
     masked_softmax = True
    r1 = False
     r2 = False
@@ -233,13 +233,13 @@ class Qwen2_5_1_5B(LLMModelConfig):
 
     num_sharding = 1
     # quant config
-    ptq = QuantDtype.use_16a8w
-    group_size = None
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 16
     masked_softmax = True
     r1 = False
     r2 = False
     r3 = True
-    custom_annotation = ()
+    custom_annotation = (annotate_output_16a8w,)
 
 
 @register_llm_model("qwen3-0_6b")
```
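The switch from `use_16a8w` to `use_16a4w_block` with `group_size = 16` trades 8-bit weights for 4-bit weights quantized blockwise: each contiguous group of 16 weights shares one scale, which is what shrinks the PTE size bound in the test above. A conceptual sketch of blockwise symmetric int4 quantization (my own illustration, not the QNN quantizer's implementation):

```python
# Conceptual sketch of 4-bit blockwise weight quantization with group_size=16:
# every group of 16 weights shares one scale, limiting the precision loss of
# the narrow int4 range.
import torch

def quantize_4bit_blockwise(w: torch.Tensor, group_size: int = 16):
    """Symmetric int4 quantization, one scale per group of `group_size` weights."""
    out_ch, in_ch = w.shape
    groups = w.reshape(out_ch, in_ch // group_size, group_size)
    # int4 symmetric range is [-8, 7]; scale each group by its max magnitude.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(groups / scales), -8, 7)
    return q, scales

w = torch.randn(8, 64)
q, scales = quantize_4bit_blockwise(w)
w_hat = (q * scales).reshape(8, 64)   # dequantize
print((w - w_hat).abs().max())        # per-group scales keep the error small
```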
