diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 54ea502275a..aac3b300f9b 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -57,4 +57,5 @@
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
+from .replace_inf_values_pass import ReplaceInfValues  # noqa  # usort: skip
 from .arm_pass_manager import ArmPassManager  # noqa  # usort: skip
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 210f006d72d..7ce07d5e73f 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -49,6 +49,7 @@
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceInfValues,
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
@@ -216,4 +217,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeSoftmaxPass())
         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ReplaceInfValues())
         return self._transform(graph_module)
diff --git a/backends/arm/_passes/replace_inf_values_pass.py b/backends/arm/_passes/replace_inf_values_pass.py
new file mode 100644
index 00000000000..8c721eda3d8
--- /dev/null
+++ b/backends/arm/_passes/replace_inf_values_pass.py
@@ -0,0 +1,45 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This pass is based on backends/qualcomm/_passes/replace_inf_values.py,
+# with some modifications to the replacement values.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ReplaceInfValues(ExportPass):
+    """
+    Due to limitations in the Quantizer, we need to change inf/-inf to more quantizable values.
+ """ + + def __init__(self): + super(ReplaceInfValues, self).__init__() + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + for buf_name, tensor in graph_module.named_buffers(): + if tensor.is_floating_point(): + modified = True + # 255 here is mainly for attention_mask in Llama for reasonable quant scale + tensor[tensor == float("inf")] = 255 + tensor[tensor == float("-inf")] = -255 + setattr(graph_module, buf_name, tensor) + + for node in graph_module.graph.nodes: + arg_list = list(node.args) + for index, arg in enumerate(arg_list): + if arg == float("-inf"): + modified = True + arg_list[index] = -255 + elif arg == float("inf"): + modified = True + arg_list[index] = +255 + node.args = tuple(arg_list) + + if modified: + graph_module.recompile() + return PassResult(graph_module, modified) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5ac747177be..aa41a9844d6 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -408,6 +408,9 @@ def any_or_hardtanh_min_zero(n: Node): shared_qspec = SharedQuantizationSpec(node.args[0]) quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type] quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + elif node.target in [torch.ops.aten.scalar_tensor.default]: + quant_properties.quant_inputs = [] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) else: return None @@ -455,5 +458,6 @@ def annotate_graph( # type: ignore[return] if node.target in [ torch.ops.aten.full_like.default, torch.ops.aten.full.default, + torch.ops.aten.scalar_tensor.default, ]: node.kwargs = {} diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py index 644ad69222c..44a8fdc2a04 100644 --- a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -105,7 +105,6 @@ def test_llama_tosa_MI(self): ) ) - @pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)") def test_llama_tosa_BI(self): llama_model, llama_inputs, llama_meta = self.prepare_model() @@ -126,7 +125,7 @@ def test_llama_tosa_BI(self): .to_executorch() .run_method_and_compare_outputs( inputs=llama_inputs, - atol=4.3, - rtol=1.1, # TODO: Tolerance needs to be updated after MLETORCH-907 + atol=9.9, + rtol=1.5, # TODO: Tolerance needs to be updated after MLETORCH-907 ) )