Skip to content

Commit d9bf112

Browse files
committed
Update
[ghstack-poisoned]
2 parents 8978c19 + b0d965c commit d9bf112

37 files changed

+1623
-204
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
4141
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
4242
from .decompose_linear_pass import DecomposeLinearPass # noqa
43+
from .decompose_masked_fill import DecomposeMaskedFill # noqa
4344
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
4445
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
4546
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
DecomposeLeakyReLUPass,
4646
DecomposeLinearPass,
4747
DecomposeLinearVectorNormPass,
48+
DecomposeMaskedFill,
4849
DecomposeMaxPool2DPass,
4950
DecomposeMeanDimPass,
5051
DecomposeNotEqualPass,
@@ -113,6 +114,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113114
self.add_pass(
114115
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
115116
)
117+
116118
self.add_pass(ConvertFullLikeToFullPass())
117119
self.add_pass(ConvertToClampPass())
118120
self.add_pass(ConvertMinMaxPass())
@@ -146,6 +148,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
146148
self.add_pass(DecomposeMaxPool2DPass())
147149
self.add_pass(SizeAdjustInputPass())
148150
self.add_pass(DecomposeSelectPass())
151+
149152
self.add_pass(ConvertSqueezesToViewPass())
150153

151154
self.add_pass(FuseViewCopyTransform())
@@ -160,6 +163,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
160163
return self._transform(exported_program.graph_module)
161164

162165
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
166+
self.add_pass(DecomposeMaskedFill())
163167
self.add_pass(DecomposeRoundPass())
164168
self.add_pass(DecomposeAcoshPass())
165169
self.add_pass(DecomposeAsinPass())
@@ -285,4 +289,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
285289
self.add_pass(ReplaceInfValues())
286290
self.add_pass(DecomposeSumPass())
287291

292+
if not self.tosa_spec.is_U55_subset:
293+
# The decomposition uses the where operator, which is not supported on Ethos-U55
294+
self.add_pass(DecomposeMaskedFill())
295+
288296
return self._transform(graph_module)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import torch
10+
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,)
16+
aten_ops = (torch.ops.aten.masked_fill.Scalar,)
17+
18+
19+
def _get_decomposition(op) -> tuple:
    """Return the ``(where, full_like)`` operator pair for the dialect of ``op``.

    An op found in ``edge_ops`` yields the edge-dialect replacements; an op
    found in ``aten_ops`` yields the ATen replacements.

    Raises:
        RuntimeError: If ``op`` is in neither ``edge_ops`` nor ``aten_ops``.
    """
    # Pick the operator namespace first; both dialects expose the same
    # overload names, so the replacement pair is built in one place.
    if op in edge_ops:
        namespace = exir_ops.edge.aten
    elif op in aten_ops:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Unable to get decomposition for op {op}")
    return (namespace.where.self, namespace.full_like.default)
31+
32+
33+
class DecomposeMaskedFill(ArmPass):
    """Rewrite ``masked_fill.Scalar`` as ``full_like`` followed by ``where``.

    ``masked_fill(x, mask, value)`` is replaced by
    ``where(mask, full_like(x, value), x)``. Any operator that is not listed in
    ``edge_ops`` or ``aten_ops`` is forwarded to the parent pass unchanged.
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        is_masked_fill = op in edge_ops or op in aten_ops
        if not is_masked_fill:
            return super().call_operator(op, args, kwargs, meta, updated)

        tensor, mask, fill_value = args
        where_op, full_like_op = _get_decomposition(op)

        # Turn the scalar into a tensor created by full_like from the input, so
        # that `where` can choose between it and the input.
        filled = super().call_operator(
            full_like_op, (tensor, fill_value), {}, meta, True
        )

        where_args = (mask, filled, tensor)
        return super().call_operator(where_op, where_args, kwargs, meta, True)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def is_node_supported(
254254
exir_ops.edge.aten.asin.default,
255255
exir_ops.edge.aten.atanh.default,
256256
exir_ops.edge.aten.addmm.default,
257+
exir_ops.edge.aten.masked_fill.Scalar,
257258
]
258259

259260
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,6 @@ def any_or_hardtanh_min_zero(n: Node):
500500
elif node.target in [operator.getitem]:
501501
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type]
502502
return None
503-
504503
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
505504
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
506505
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU85PipelineBI,
14+
OpNotSupportedPipeline,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
20+
# Fully qualified name of the ATen overload under test. ATen operators are
# exposed under `torch.ops.aten`, so the path must resolve from the `torch`
# module (the previous "torch.aten.ops" spelling did not).
aten_op = "torch.ops.aten.masked_fill.Scalar"
21+
# Edge-dialect op name handed to the U55 and U85 pipelines in this file.
# NOTE(review): the overload suffix is lowercase "scalar" here, while the op is
# referenced as `masked_fill.Scalar` elsewhere; if the pipelines match this
# string case-sensitively the checks can never hit -- confirm.
exir_op = "executorch_exir_dialects_edge__ops_aten_masked_fill_scalar"
22+
23+
input_t = Tuple[torch.Tensor, torch.Tensor, float]
24+
25+
26+
class MaskedFill(torch.nn.Module):
    """Minimal module whose forward pass is a single scalar ``masked_fill``."""

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor, value: float
    ) -> torch.Tensor:
        """Return the result of ``torch.masked_fill(x, mask, value)``."""
        filled = torch.masked_fill(x, mask, value)
        return filled
31+
32+
33+
test_modules = {
34+
"masked_fill_1": lambda: (
35+
MaskedFill(),
36+
(
37+
torch.rand(1, 3, 4, 5),
38+
(torch.rand(1, 3, 4, 5) < 0.5), # boolean mask
39+
-1.0,
40+
),
41+
),
42+
"masked_fill_2": lambda: (
43+
MaskedFill(),
44+
(
45+
torch.rand(1, 10, 10, 10),
46+
(torch.rand(1, 10, 10, 10) > 0.75),
47+
3.14,
48+
),
49+
),
50+
"masked_fill_3_zero_fill": lambda: (
51+
MaskedFill(),
52+
(
53+
torch.rand(1, 3, 4, 5),
54+
torch.rand(1, 3, 4, 5) < 0.2,
55+
0.0,
56+
),
57+
),
58+
"masked_fill_4_full_mask": lambda: (
59+
MaskedFill(),
60+
(
61+
torch.rand(1, 3, 4, 5),
62+
torch.ones(1, 3, 4, 5, dtype=torch.bool),
63+
7.0,
64+
),
65+
),
66+
"masked_fill_5_no_mask": lambda: (
67+
MaskedFill(),
68+
(
69+
torch.rand(1, 3, 4, 5),
70+
torch.zeros(1, 3, 4, 5, dtype=torch.bool),
71+
-3.0,
72+
),
73+
),
74+
"masked_fill_6_scalar_broadcast": lambda: (
75+
MaskedFill(),
76+
(
77+
torch.rand(1, 1, 1, 1),
78+
torch.tensor([[[[True]]]]),
79+
42.0,
80+
),
81+
),
82+
"masked_fill_7_large_tensor": lambda: (
83+
MaskedFill(),
84+
(
85+
torch.rand(1, 8, 8, 8),
86+
torch.rand(1, 8, 8, 8) > 0.5,
87+
-127.0,
88+
),
89+
),
90+
"masked_fill_8_extreme_scalar_inf": lambda: (
91+
MaskedFill(),
92+
(
93+
torch.rand(1, 3, 7, 5),
94+
torch.rand(1, 3, 7, 5) > 0.5,
95+
float("inf"),
96+
),
97+
),
98+
}
99+
100+
101+
@common.parametrize("test_module", test_modules)
def test_masked_fill_scalar_tosa_MI(test_module):
    """Run every masked_fill case through the TOSA MI pipeline."""
    model, example_inputs = test_module()
    TosaPipelineMI[input_t](model, example_inputs, aten_op=[]).run()
106+
107+
108+
@common.parametrize("test_module", test_modules)
def test_masked_fill_scalar_tosa_BI(test_module):
    """Run every masked_fill case through the TOSA BI pipeline."""
    model, example_inputs = test_module()
    TosaPipelineBI[input_t](model, example_inputs, aten_op=[]).run()
117+
118+
119+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_masked_fill_scalar_u55_BI(test_module):
    """Expect no delegation of masked_fill when targeting the U55 subset."""
    model, example_inputs = test_module()
    # Op counts passed to the pipeline: no masked_fill, one `where`.
    expected_op_counts = {
        exir_op: 0,
        "executorch_exir_dialects_edge__ops_aten_where_self": 1,
    }
    pipeline = OpNotSupportedPipeline[input_t](
        model,
        example_inputs,
        expected_op_counts,
        n_expected_delegates=0,
        quantize=True,
        u55_subset=True,
    )
    pipeline.run()
132+
133+
134+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_masked_fill_scalar_u85_BI(test_module):
    """Run every masked_fill case through the Ethos-U85 BI pipeline."""
    model, example_inputs = test_module()
    EthosU85PipelineBI[input_t](
        model, example_inputs, aten_ops=[], exir_ops=exir_op
    ).run()

backends/arm/test/ops/test_multihead_attention.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
EthosU85PipelineBI,
1212
TosaPipelineBI,
1313
TosaPipelineMI,
14+
VgfPipeline,
1415
)
1516

1617

@@ -105,3 +106,39 @@ def test_multihead_attention_u85_BI(test_data: input_t1):
105106
per_channel_quantization=False,
106107
)
107108
pipeline.run()
109+
110+
111+
@common.parametrize("test_data", test_suite)
@common.SkipIfNoModelConverter
def test_multihead_attention_vgf_FP(test_data: input_t1):
    """Run the multi-head attention cases through the VGF pipeline (TOSA-1.0+FP)."""
    test_data_vals, module = test_data()
    # The same values are unpacked three times into one flat input tuple
    # (presumably query, key and value -- confirm against the module's forward).
    inputs = (*test_data_vals, *test_data_vals, *test_data_vals)
    pipeline = VgfPipeline[input_t1](
        module, inputs, [], [], tosa_version="TOSA-1.0+FP"
    )
    pipeline.run()
126+
127+
128+
@common.parametrize("test_data", test_suite)
@common.SkipIfNoModelConverter
def test_multihead_attention_vgf_INT(test_data: input_t1):
    """Run the multi-head attention cases through the VGF pipeline (TOSA-1.0+INT)."""
    test_data_vals, module = test_data()
    # The same values are unpacked three times into one flat input tuple
    # (presumably query, key and value -- confirm against the module's forward).
    inputs = (*test_data_vals, *test_data_vals, *test_data_vals)
    pipeline = VgfPipeline[input_t1](
        module,
        inputs,
        [],
        [],
        tosa_version="TOSA-1.0+INT",
        # TODO: Per-channel quantization is broken (MLETORCH-1144)
        per_channel_quantization=False,
    )
    pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def __init__(
854854
vgf_compiler_flags: Optional[str] = "",
855855
tosa_version: str = "TOSA-1.0+FP",
856856
symmetric_io_quantization: bool = False,
857-
per_channel_quantization: bool = False,
857+
per_channel_quantization: bool = True,
858858
use_to_edge_transform_and_lower: bool = True,
859859
custom_path: str = None,
860860
atol: float = 1e-03,
@@ -866,11 +866,6 @@ def __init__(
866866
] = None,
867867
):
868868

869-
if (
870-
symmetric_io_quantization or per_channel_quantization
871-
) and tosa_version == "TOSA-1.0+FP":
872-
raise ValueError("Dont configure quantization with FP TOSA profile.")
873-
874869
tosa_profile = TosaSpecification.create_from_string(tosa_version)
875870
compile_spec = common.get_vgf_compile_spec(
876871
tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path
@@ -887,18 +882,15 @@ def __init__(
887882
transform_passes=transform_passes,
888883
)
889884

890-
if symmetric_io_quantization or per_channel_quantization:
885+
if "INT" in tosa_version:
891886
quantizer = VgfQuantizer(compile_spec)
892887
quantization_config = get_symmetric_quantization_config(
893888
is_per_channel=per_channel_quantization
894889
)
895890
if symmetric_io_quantization:
896891
quantizer.set_io(quantization_config)
897892
quant_stage = Quantize(quantizer, quantization_config)
898-
else:
899-
quant_stage = None
900893

901-
if "INT" in tosa_version:
902894
self.add_stage(self.tester.quantize, quant_stage, pos=0)
903895

904896
self.add_stage_after(

backends/cadence/aot/compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def trace(
5959
dump_graphs: bool = False,
6060
) -> ExportedProgram:
6161
"""
62-
Trace the model with export_for_training and return an ExportedProgram.
62+
Trace the model with export and return an ExportedProgram.
6363
"""
6464

6565
# Make the model inference mode by calling model.eval()
@@ -83,9 +83,9 @@ def trace(
8383
remove_decompositions(decomp_table, ops_to_keep)
8484

8585
# Export with dynamo
86-
program = torch.export.export_for_training(
87-
model, inputs, strict=True
88-
).run_decompositions(decomp_table)
86+
program = torch.export.export(model, inputs, strict=True).run_decompositions(
87+
decomp_table
88+
)
8989

9090
if dump_graphs:
9191
logging.info("Graph before quantization:")

backends/nxp/backend/edge_program_converter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,20 @@
2323

2424
# noinspection PyProtectedMember
2525
functions_converters = {
26+
exir_ops.edge.aten.abs.default: AbsConverter, # noqa F405
27+
exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter, # noqa F405
2628
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
29+
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
2730
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
31+
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
2832
exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405
2933
exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405
34+
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405
3035
exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter, # noqa F405
36+
exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405
3137
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
3238
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
3339
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
34-
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405
3540
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
3641
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
3742
}

0 commit comments

Comments
 (0)