Skip to content

Commit dec6aa4

Browse files
authored
Merge branch 'main' into order2
2 parents a5aa40a + 7a7e939 commit dec6aa4

38 files changed

+1051
-134
lines changed

.ci/scripts/test_model.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ test_model() {
9797
bash examples/models/llava/install_requirements.sh
9898
STRICT="--no-strict"
9999
fi
100-
if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then
100+
if [[ "${MODEL_NAME}" == "qwen2_5_1_5b" ]]; then
101101
# Install requirements for export_llama
102102
bash examples/models/llama/install_requirements.sh
103103
# Test export_llm script: python3 -m extension.llm.export.export_llm.

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ jobs:
176176
- model: phi_4_mini
177177
backend: portable
178178
runner: linux.arm64.m7g.4xlarge
179-
- model: qwen2_5
179+
- model: qwen2_5_1_5b
180180
backend: portable
181181
runner: linux.arm64.2xlarge
182182
- model: llama3_2_vision_encoder

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ To get started you can:
5252

5353
- Visit the [Step by Step Tutorial](https://pytorch.org/executorch/stable/getting-started.html) to get things running locally and deploy a model to a device
5454
- Use this [Colab Notebook](https://colab.research.google.com/drive/1qpxrXC3YdJQzly3mRg-4ayYiOjC6rue3?usp=sharing) to start playing around right away
55-
- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), and [Llava](examples/models/llava/README.md)
55+
- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), [Llava](examples/models/llava/README.md), [Voxtral](examples/models/voxtral/README.md), and [LFM2](examples/models/lfm2/README.md).
5656

5757
## Feedback and Engagement
5858

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3838
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
3939
from .decompose_div_pass import DecomposeDivPass # noqa
40+
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
4041
from .decompose_elu_pass import DecomposeEluPass # noqa
4142
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4243
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
DecomposeCosineSimilarityPass,
4343
DecomposeCumsumPass,
4444
DecomposeDivPass,
45+
DecomposeDivTensorModePass,
4546
DecomposeEluPass,
4647
DecomposeEmbeddingPass,
4748
DecomposeExpm1Pass,
@@ -211,6 +212,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
211212
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
212213
)
213214
self.add_pass(DecomposeNotEqualPass())
215+
self.add_pass(DecomposeDivTensorModePass())
214216
self.add_pass(DecomposeDivPass())
215217
self.add_pass(DecomposeSoftmaxPass())
216218
self.add_pass(DecomposeGeluPass())
@@ -289,6 +291,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
289291
self.add_pass(DecomposeNotEqualPass())
290292
self.add_pass(DecomposeCosineSimilarityPass())
291293
self.add_pass(DecomposeGluPass())
294+
self.add_pass(DecomposeDivTensorModePass())
292295
self.add_pass(DecomposeDivPass())
293296
self.add_pass(DecomposeLeakyReLUPass())
294297
self.add_pass(DecomposeLinearVectorNormPass())
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,)
13+
aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,)
14+
15+
edge_unary = {
16+
"div": exir_ops.edge.aten.div.Tensor,
17+
"floor": exir_ops.edge.aten.floor.default,
18+
"ceil": exir_ops.edge.aten.ceil.default,
19+
"full": exir_ops.edge.aten.full.default,
20+
"lt": exir_ops.edge.aten.lt.Tensor,
21+
"where": exir_ops.edge.aten.where.self,
22+
}
23+
24+
aten_unary = {
25+
"div": torch.ops.aten.div.Tensor,
26+
"floor": torch.ops.aten.floor.default,
27+
"ceil": torch.ops.aten.ceil.default,
28+
"full": torch.ops.aten.full.default,
29+
"lt": torch.ops.aten.lt.Tensor,
30+
"where": torch.ops.aten.where.self,
31+
}
32+
33+
34+
def _get_opset(op):
35+
if op in edge_div_mode_ops:
36+
return edge_unary
37+
if op in aten_div_mode_ops:
38+
return aten_unary
39+
raise RuntimeError(f"div.Tensor_mode not supported for op {op}")
40+
41+
42+
class DecomposeDivTensorModePass(ExportPass):
43+
"""
44+
Rewrites aten.div.Tensor_mode into
45+
46+
rounding_mode=None -> div(a, b)
47+
rounding_mode='floor' -> floor(div(a, b))
48+
rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b)))
49+
"""
50+
51+
def call_operator(self, op, args, kwargs, meta):
52+
if op not in (edge_div_mode_ops + aten_div_mode_ops):
53+
return super().call_operator(op, args, kwargs, meta)
54+
55+
opset = _get_opset(op)
56+
57+
a, b = args[0], args[1]
58+
rounding_mode = kwargs.get("rounding_mode", None)
59+
if rounding_mode is None and len(args) > 2:
60+
rounding_mode = args[2]
61+
62+
q = super().call_operator(opset["div"], (a, b), {}, meta)
63+
64+
if rounding_mode is None:
65+
return q
66+
67+
if rounding_mode == "floor":
68+
return super().call_operator(opset["floor"], (q,), {}, meta)
69+
70+
if rounding_mode == "trunc":
71+
zero = super().call_operator(
72+
opset["full"],
73+
args=((1,) * len(meta["val"].size()), 0.0),
74+
kwargs={"dtype": torch.float32},
75+
meta=meta,
76+
)
77+
lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta)
78+
ceilq = self.call_operator(opset["ceil"], (q,), {}, meta)
79+
floorq = self.call_operator(opset["floor"], (q,), {}, meta)
80+
return self.call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta)
81+
82+
raise RuntimeError(
83+
f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}"
84+
)

backends/arm/arm_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
class ArmCompileSpecBuilder:
2424
class DebugMode(Enum):
2525
JSON = 1
26+
TOSA = 2
2627

2728
def __init__(self):
2829
self.compile_spec: List[CompileSpec] = []

backends/arm/debug/schema.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
import json
99

1010
from dataclasses import asdict, dataclass
11-
from typing import Any
11+
from typing import Any, Optional
1212

1313
import serializer.tosa_serializer as ts # type: ignore
1414
import torch
1515

16+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
17+
1618
from torch.fx.traceback import NodeSource
1719

1820

@@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema:
9799
class DebugSchema:
98100
event_id: int
99101
aten_info: ATenDebugSchema
100-
tosa_info: TosaDebugSchema
102+
tosa_info: Optional[TosaDebugSchema]
101103
torch_info: TorchDebugSchema
102104

105+
def to_dict(self) -> dict[str, Any]:
106+
output = asdict(self)
107+
108+
if self.tosa_info is None:
109+
output.pop("tosa_info")
110+
111+
return output
112+
103113

104114
class DebugHook:
105-
def __init__(self) -> None:
115+
def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None:
106116
self._debug_events: list[DebugSchema] = []
107117
self.__op_id_to_name = {}
118+
self.mode = debug_mode
108119

109120
# Build up a mapping from TOSA 1.0 operator IDs to their names
110121
for name, val in vars(ts.Op).items():
111122
self.__op_id_to_name[val] = name
112123

113-
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
114-
tosa_debug_info = TosaDebugSchema(
115-
node_name=str(tosa_op),
116-
operator_name=self.__op_id_to_name[tosa_op_id],
117-
operator_id=tosa_op_id,
118-
)
124+
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
125+
tosa_debug_info = None
126+
127+
# If the debug data is being embedded into the TOSA flatbuffer
128+
# do not collect TOSADebugSchema data, it's redundent
129+
if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA:
130+
tosa_debug_info = TosaDebugSchema(
131+
node_name=str(tosa_op),
132+
operator_name=self.__op_id_to_name[tosa_op_id],
133+
operator_id=tosa_op_id,
134+
)
119135

120136
aten_debug_info = ATenDebugSchema.from_node(node)
121137
torch_debug_info = TorchDebugSchema.from_node(node)
122138

123-
self._debug_events.append(
124-
DebugSchema(
125-
event_id=len(self._debug_events),
126-
aten_info=aten_debug_info,
127-
tosa_info=tosa_debug_info,
128-
torch_info=torch_debug_info,
129-
)
139+
debug_info = DebugSchema(
140+
event_id=len(self._debug_events),
141+
aten_info=aten_debug_info,
142+
tosa_info=tosa_debug_info,
143+
torch_info=torch_debug_info,
130144
)
145+
self._debug_events.append(debug_info)
146+
147+
return debug_info
131148

132149
def serialize(self) -> str:
133-
return json.dumps([asdict(event) for event in self._debug_events], indent=4)
150+
return json.dumps([event.to_dict() for event in self._debug_events], indent=4)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def is_node_supported(
176176
exir_ops.edge.aten.hardtanh.default,
177177
exir_ops.edge.aten.hardswish.default,
178178
exir_ops.edge.aten.div.Tensor,
179+
exir_ops.edge.aten.div.Tensor_mode,
179180
exir_ops.edge.aten.eq.Tensor,
180181
exir_ops.edge.aten.eq.Scalar,
181182
exir_ops.edge.aten.erf.default,

backends/arm/operators/node_visitor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
# pyre-unsafe
77

8+
import json
89
from typing import Any, Dict, List, Optional
910

1011
import torch
1112

13+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
1214
from executorch.backends.arm.debug.schema import DebugHook
1315
from executorch.backends.arm.tosa.mapping import TosaArg
1416
from executorch.backends.arm.tosa.specification import TosaSpecification
@@ -49,20 +51,25 @@ def _serialize_operator(
4951
outputs: List[str],
5052
attributes: Optional[Any] = None,
5153
) -> None:
54+
op_location = ""
55+
if self.debug_hook:
56+
debug_info = self.debug_hook.add(
57+
node,
58+
tosa_op=outputs[0],
59+
tosa_op_id=tosa_op,
60+
)
61+
62+
if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA:
63+
op_location = json.dumps(debug_info.to_dict())
64+
5265
tosa_graph.addOperator(
5366
tosa_op,
5467
inputs=inputs,
5568
outputs=outputs,
5669
attributes=attributes,
70+
location=op_location,
5771
)
5872

59-
if self.debug_hook:
60-
self.debug_hook.add(
61-
node,
62-
tosa_op=outputs[0],
63-
tosa_op_id=tosa_op,
64-
)
65-
6673
def define_node(
6774
self,
6875
node: torch.fx.Node,

0 commit comments

Comments
 (0)