Commit 2d8919f

Merge branch 'main' into an/openvino/quantizer_fix
2 parents 43f4b1f + 7fa93a7 commit 2d8919f

22 files changed, +579 -44 lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
 from .decompose_int16_activation_conv2d_pass import ( # noqa
     DecomposeConv2dWithInt16ActivationPass,
 )
+from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
 from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion
@@ -55,6 +55,7 @@
     DecomposeGluPass,
     DecomposeGroupedConvPass,
     DecomposeGroupNormPass,
+    DecomposeInt32ClampPass,
     DecomposeIntPowPass,
     DecomposeLayerNormPass,
     DecomposeLeakyReLUPass,
@@ -122,7 +123,6 @@


 class ArmPassManager(PassManager):
-
     def __init__(self, tosa_spec: TosaSpecification) -> None:
         self.tosa_spec = tosa_spec
         super().__init__()
@@ -174,6 +174,7 @@ def _tosa_pipeline(
             FuseQuantizedActivationPass(),
             RemoveGetItemPass(),
             ConvertToClampPass(),
+            DecomposeInt32ClampPass(),
             DecomposeGroupNormPass(),
             DecomposeLayerNormPass(),
             DecomposeBatchNormNoStatsPass(),
backends/arm/_passes/decompose_int32_clamp_pass.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeInt32ClampPass(ArmPass):
+    """Rewrite int32 clamp into min/max chain since TOSA lacks int32 clamp support."""
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+    _supported_ops = {
+        exir_ops.edge.aten.clamp.default,
+        torch.ops.aten.clamp.default,
+    }
+
+    def _ensure_tensor(
+        self,
+        value,
+        ref_tensor,
+        dtype,
+        rank,
+        meta,
+    ):
+        if value is None:
+            return None
+        return super().call_operator(
+            exir_ops.edge.aten.full.default,
+            ((1,) * rank, value),
+            {"dtype": dtype},
+            meta,
+            updated=True,
+        )
+
+    def call_operator(self, op, args, kwargs, meta):
+        val = meta["val"]
+        if op not in self._supported_ops or val.dtype != torch.int32:
+            return super().call_operator(op, args, kwargs, meta)
+
+        input_tensor = args[0]
+        min_arg = args[1] if len(args) > 1 else None
+        max_arg = args[2] if len(args) > 2 else None
+        dtype = val.dtype
+        rank = len(val.shape)
+
+        min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
+        max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
+
+        current = input_tensor
+        if max_arg is not None:
+            current = super().call_operator(
+                exir_ops.edge.aten.minimum.default,
+                (current, max_arg),
+                {},
+                meta,
+                updated=True,
+            )
+        if min_arg is not None:
+            current = super().call_operator(
+                exir_ops.edge.aten.maximum.default,
+                (current, min_arg),
+                {},
+                meta,
+                updated=True,
+            )
+        return current

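Note (not part of the commit): the pass relies on the identity clamp(x, lo, hi) == maximum(minimum(x, hi), lo). A minimal sketch checking that identity with plain PyTorch, using arbitrary shapes and bounds and the same (1,) * rank broadcast shape the pass emits for the bounds:

import torch

x = torch.randint(-50, 50, (2, 3), dtype=torch.int32)
lo = torch.full((1, 1), -10, dtype=torch.int32)  # broadcastable bound, shape (1,) * rank
hi = torch.full((1, 1), 5, dtype=torch.int32)

# min/max chain produced by the decomposition vs. the original clamp
decomposed = torch.maximum(torch.minimum(x, hi), lo)
assert torch.equal(decomposed, torch.clamp(x, -10, 5))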
backends/arm/operators/node_visitor.py

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def _serialize_operator(
            None: Mutates ``tosa_graph`` in place.

        """
-        op_location = ts.TosaOpLocation()
+        op_location = None
        if self.debug_hook:
            debug_info = self.debug_hook.add(
                node,
@@ -96,7 +96,7 @@ def _serialize_operator(
            )

            if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA:
-                op_location.text = json.dumps(debug_info.to_dict())
+                op_location = json.dumps(debug_info.to_dict())

        tosa_graph.addOperator(
            tosa_op,

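For reference (not part of the diff): the location text attached to each operator is just a JSON-encoded debug record, so it can be round-tripped with the standard library. A minimal sketch with a made-up payload; the real dict comes from the debug hook:

import json

debug_info = {"node": "aten_clamp_default", "stack_trace": "model.py:42"}  # hypothetical payload

location_text = json.dumps(debug_info)  # what ends up in the operator's location text
assert json.loads(location_text) == debug_info  # what the debug-feature test round-trips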
backends/arm/operators/op_clamp.py

Lines changed: 6 additions & 2 deletions
@@ -40,7 +40,6 @@ def __init__(self, *args):
    def _get_min_max_arguments(
        self, node: Node, dtype: torch.dtype
    ) -> Tuple[int | float, int | float]:
-
        def cast_type(value: Any) -> int | float:
            if isinstance(value, int):
                return value
@@ -91,7 +90,12 @@ def define_node(
        validate_valid_dtype(
            self.target,
            [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.FP16,
+                ts.DType.FP32,
+            ],
            output.tosa_spec,
        )

backends/arm/test/misc/test_debug_feats.py

Lines changed: 39 additions & 16 deletions
@@ -16,6 +16,7 @@
 import torch
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
 from executorch.backends.arm.test import common
+from executorch.backends.arm.test.runner_utils import dbg_tosa_fb_to_json
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     TosaPipelineFP,
@@ -238,25 +239,47 @@ def test_dump_tosa_debug_json(test_data: input_t1):

 @common.parametrize("test_data", Linear.inputs)
 def test_dump_tosa_debug_tosa(test_data: input_t1):
-    with tempfile.TemporaryDirectory() as tmpdir:
-        aten_ops: list[str] = []
-        exir_ops: list[str] = []
-        pipeline = TosaPipelineINT[input_t1](
-            module=Linear(),
-            test_data=test_data,
-            aten_op=aten_ops,
-            exir_op=exir_ops,
-            custom_path=tmpdir,
-            tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA,
-        )
+    output_dir = "test_dump_tosa_debug"

-        pipeline.pop_stage("run_method_and_compare_outputs")
-        pipeline.run()
+    aten_ops: list[str] = []
+    exir_ops: list[str] = []
+    pipeline = TosaPipelineFP[input_t1](
+        module=Linear(),
+        test_data=test_data,
+        use_to_edge_transform_and_lower=True,
+        aten_op=aten_ops,
+        exir_op=exir_ops,
+        custom_path=output_dir,
+        tosa_debug_mode=ArmCompileSpec.DebugMode.TOSA,
+    )

-        json_output_path = Path(tmpdir) / "debug.json"
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+    output_path = Path(output_dir)
+    json_output_path = output_path / "debug.json"
+
+    # A JSON file should not be created when TOSA mode used
+    assert not json_output_path.exists()
+
+    # At least one TOSA file should exist
+    tosa_files = list(output_path.glob("*.tosa"))
+    assert len(tosa_files) > 0
+
+    tosa_file = tosa_files[0]
+    with tosa_file.open("rb") as f:
+        tosa_json = dbg_tosa_fb_to_json(f.read())
+
+    # Check all non-empty JSON strings are valid
+    ops = tosa_json["regions"][0]["blocks"][0]["operators"]
+    for op in ops:
+        if op["location"]["text"]:
+            try:
+                json.loads(op["location"]["text"])
+            except json.JSONDecodeError:
+                pytest.fail("Failed to load debug JSON string")

-        # A JSON file should not be created when TOSA mode used
-        assert not json_output_path.exists()
+    shutil.rmtree(output_dir, ignore_errors=True)


 @common.parametrize("test_data", Linear.inputs)

backends/arm/test/ops/test_clamp.py

Lines changed: 35 additions & 4 deletions
@@ -35,6 +35,25 @@
     "rank_4_no_max": lambda: (torch.rand(1, 10, 10, 1) - 3, -3.3, None),
 }

+test_data_suite_int32 = {
+    "int32_rank2": lambda: (torch.randint(-50, 50, (2, 3), dtype=torch.int32), -10, 10),
+    "int32_rank3_no_min": lambda: (
+        torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32),
+        None,
+        25,
+    ),
+    "int32_rank3_no_max": lambda: (
+        torch.randint(-100, 100, (1, 3, 3), dtype=torch.int32),
+        -25,
+        None,
+    ),
+    "int32_rank4_large_range": lambda: (
+        torch.randint(-200, 200, (1, 2, 4, 4), dtype=torch.int32),
+        torch.iinfo(torch.int32).min,
+        torch.iinfo(torch.int32).max,
+    ),
+}
+

 class Clamp(torch.nn.Module):
     def __init__(
@@ -53,7 +72,6 @@ def forward(self, x):

 @common.parametrize("test_data", test_data_suite)
 def test_clamp_tosa_FP(test_data):
-
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)

@@ -69,7 +87,6 @@ def test_clamp_tosa_FP(test_data):

 @common.parametrize("test_data", test_data_suite)
 def test_clamp_tosa_INT(test_data):
-
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)

@@ -84,6 +101,22 @@ def test_clamp_tosa_INT(test_data):
     pipeline.run()


+@common.parametrize("test_data", test_data_suite_int32)
+def test_clamp_tosa_INT_int32_inputs(test_data):
+    input_tensor, min_val, max_val = test_data()
+    model = Clamp(min_val, max_val)
+
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        (input_tensor,),
+        aten_op,
+        exir_op,
+    )
+    pipeline.change_args("run_method_and_compare_outputs", qtol=1)
+    pipeline.pop_stage("quantize")
+    pipeline.run()
+
+
 @common.parametrize("test_data", test_data_suite)
 def test_clamp_tosa_INT_a16w8(test_data):
     """Test clamp operation with int16 I/O quantization for TOSA INT."""
@@ -103,7 +136,6 @@ def test_clamp_tosa_INT_a16w8(test_data):
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone300
 def test_clamp_u55_INT(test_data):
-
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)

@@ -140,7 +172,6 @@ def test_clamp_16a8w_u55_INT16(test_data):
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
 def test_clamp_u85_INT(test_data):
-
     input_tensor, min_val, max_val = test_data()
     model = Clamp(min_val, max_val)

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm._passes.decompose_int32_clamp_pass import (
+    DecomposeInt32ClampPass,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+input_t = Tuple[torch.Tensor]
+
+
+class ClampInt32(torch.nn.Module):
+    test_data = {"rand": (torch.randint(-50, 50, (2, 3), dtype=torch.int32),)}
+
+    def forward(self, x: torch.Tensor):
+        return torch.clamp(x, -10, 5)
+
+
+@common.parametrize("test_data", ClampInt32.test_data)
+def test_decompose_int32_clamp_pass(test_data: input_t):
+    module = ClampInt32()
+    pipeline = PassPipeline[input_t](
+        module,
+        test_data,
+        quantize=False,
+        ops_before_pass={
+            "executorch_exir_dialects_edge__ops_aten_clamp_default": 1,
+        },
+        ops_after_pass={
+            "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
+            "executorch_exir_dialects_edge__ops_aten_maximum_default": 1,
+        },
+        ops_not_after_pass=[
+            "executorch_exir_dialects_edge__ops_aten_clamp_default",
+        ],
+        pass_list=[DecomposeInt32ClampPass],
+    )
+    pipeline.run()

backends/arm/tosa/schemas/tosa_1.0.fbs

Lines changed: 22 additions & 3 deletions
@@ -510,13 +510,31 @@ table TosaTensor {
   variable: bool; // is this a variable tensor
   is_unranked: bool; // whether this is an unranked tensor
   variable_name:string; // name for variable attribute
+
+  // In a model that is larger than 2GB, then tensors instead uses the following
+  // attributes to find stored data, which is outside of flatbuffers
+  // the offset is calculated relative to the beginning of the file and is only
+  // valid if > 1.
+  offset: ulong;
+  size: ulong;
+}
+
+table TosaShape {
+  name: string; // name of the shape
+  rank: uint32; // rank of the shape
+  data: [ubyte] (force_align: 8); // raw data array if it's a constant shape
+}
+
+table OpLocation {
+  text: string; // Opaque string, interpretted by user
 }

 table TosaOperator {
   op:Op; // operator enum
   attribute:Attribute; // union structure. operator attribute
-  inputs:[string]; // list of input tensor names
-  outputs:[string]; // list of output tensor names
+  inputs:[string]; // list of input tensor or shape names
+  outputs:[string]; // list of output tensor or shape names
+  location: OpLocation; // location of this Op in mlir
 }

 table TosaBasicBlock {
@@ -525,6 +543,7 @@ table TosaBasicBlock {
   tensors:[TosaTensor]; // tensors array
   inputs:[string]; // name of graph inputs
   outputs:[string]; // name of graph outputs
+  shapes:[TosaShape]; // shapes array
 }

 table TosaRegion {
@@ -537,4 +556,4 @@ table TosaGraph {
   regions:[TosaRegion]; // regions array
 }

-root_type TosaGraph;
+root_type TosaGraph;
