diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py
index e9f6eec63a3..541638b830e 100644
--- a/backends/arm/_passes/insert_rescales_pass.py
+++ b/backends/arm/_passes/insert_rescales_pass.py
@@ -38,17 +38,17 @@ def rescale_fake(
     """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
     Additionally validates TOSA constraints of a RESCALE op.
     """
-    if not (dtype == torch.int32 or dtype == torch.int8):
+    if dtype not in (torch.int32, torch.int8, torch.int16):
         raise NotImplementedError(
-            "tosa::rescale currently only supports int32 and int8."
+            f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}."
         )
-    if dtype == torch.int32 and out_zp != 0:
+    if dtype in (torch.int32, torch.int16) and out_zp != 0:
         raise ValueError(
-            "TOSA requires output_zp to be zero when the output dtype is int32."
+            f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
         )
-    if x.dtype == torch.int32 and in_zp != 0:
+    if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
         raise ValueError(
-            "TOSA requires input_zp to be zero when the input dtype is int32."
+            f"TOSA requires input_zp to be zero when the input dtype is {x.dtype}."
         )
     if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
         raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index 77de46fcd29..b3f5c9622bb 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -18,6 +17,7 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
+
 from torch.library import impl, Library
 
 lib = Library("tosa", "DEF")
@@ -26,7 +26,10 @@
 @impl(lib, "_table")
 def _table_impl(*args, **kwargs):  # pyre-ignore
-    return args[0]
+    in_dtype = args[0].dtype
+    if in_dtype == torch.int8:
+        return args[0]
+    return args[0].to(dtype=torch.int32)
 
 
 class InsertTableOpsPass(ExportPass):
@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
         self.exported_program.state_dict[buffer_name] = buffer
 
-    def generate_table_values(
+    def generate_8bit_table_values(
         self,
         torch_op: Callable[[torch.Tensor], torch.Tensor],
         in_quantargs: QuantArgs,
         out_quantargs: QuantArgs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT8 TOSA.TABLE. Also returns 0 since no shifting is required after an 8-bit table.
+        The INT8 table is a simple 256-value 1:1 LUT.
+        """
+
         def f(x: torch.Tensor) -> torch.Tensor:
             x = in_quantargs.dequantize_value(x)
             x = torch_op(x)
             return out_quantargs.quantize_value(x)
 
-        input_dtype = in_quantargs.dtype
-        steps = in_quantargs.qmax - in_quantargs.qmin + 1
-        return f(
+        return (
+            f(
+                torch.linspace(
+                    start=in_quantargs.qmin,
+                    end=in_quantargs.qmax,
+                    steps=256,
+                    # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
+                    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
+                    dtype=torch.int64,
+                )
+            ).to(dtype=torch.int8),
+            0,
+        )
+
+    def generate_16_bit_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT16 TOSA.TABLE with 32-bit output.
+        In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits, see
+        the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
+        will be interpreted as 2**7 = 128 times too large unless accounted for by rescaling down the table output.
+
+        Quantization can be either int16 or int32, which means that the op output could be larger than the 23 bits from
+        the TOSA.TABLE output. In that case, we need to rescale up the output.
+
+        To handle this we need to:
+        1) Make sure that our table values fit within 16 bits.
+        2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
+
+        The function returns rescale_lshift, which says how much to rescale after the table. This value can be negative.
+        """
+
+        def f(x: torch.Tensor) -> torch.Tensor:
+            # Don't use the 7 LSBs.
+            x = in_quantargs.dequantize_value((x & ~0x7F))
+            x = torch_op(x)
+            return out_quantargs.quantize_value(x)
+
+        lut_values = f(
             torch.linspace(
                 start=in_quantargs.qmin,
-                end=in_quantargs.qmax,
-                steps=steps,
+                end=in_quantargs.qmax + 1,
+                steps=513,
                 # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
                 # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
                 dtype=torch.int64,
             )
-        ).to(dtype=input_dtype)
+        )
+        # Calculate how much we need to shift table values to fit in 16 signed bits:
+        # ceil(log2(max absolute table value)) + 1 bit for signedness - 16
+        # Example:
+        # Max value in the table is 70 000. We want to fit it in 16 signed bits.
+        # 70 000 = 0b10001000101110000 (17 bits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
+        # If we shift it 17-16=1 bit, we do get 16 bits (0b1000100010111000),
+        # but due to signedness this is a negative number! So we need to shift it one more bit.
+        # Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7.
+        rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
+        # The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
+        rescale_lshift = rshift - 7
+        lut_values = lut_values >> rshift
+        return lut_values.to(dtype=torch.int16), rescale_lshift
+
+    def generate_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        match out_quantargs.dtype:
+            case torch.int8:
+                return self.generate_8bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case torch.int16 | torch.int32:
+                return self.generate_16_bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported output dtype for table: {out_quantargs.dtype}"
+                )
 
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     op_target=torch.ops.tosa._table.default,
                     args=(node.args[0],),
                 )
+                output_node = table_node
                 assert len(input_qparams) == 1
                 assert len(output_qparams) == 1
-                # Generate table buffer
-                buffer = self.generate_table_values(
+
+                # Generate table buffer and how much to lshift the table output.
+                buffer, lshift = self.generate_table_values(
                     torch_op=self.table_ops[node.target],
                     in_quantargs=input_qparams[0],
                     out_quantargs=output_qparams[0],
@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 self.register_buffer(
                     buffer_name=table_node.name.replace("_default", ""), buffer=buffer
                 )
-                node.replace_all_uses_with(table_node)
+
+                if lshift != 0:
+                    scale = 2.0**lshift
+                    rescale_node = create_node(
+                        graph=graph_module.graph,
+                        op_target=torch.ops.tosa._rescale.default,
+                        args=(table_node, output_qparams[0].dtype, scale, 0, 0),
+                    )
+                    output_node = rescale_node
+
+                node.replace_all_uses_with(output_node)
                 graph_module.graph.erase_node(node)
-                table_node.meta["input_qparams"] = input_qparams
-                table_node.meta["output_qparams"] = output_qparams
+                output_node.meta["input_qparams"] = input_qparams
+                output_node.meta["output_qparams"] = output_qparams
                 modified = True
 
         if modified:
diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py
index afb5f93baa7..f2c7ce9f9ce 100644
--- a/backends/arm/operators/node_visitor.py
+++ b/backends/arm/operators/node_visitor.py
@@ -30,7 +30,7 @@ class NodeVisitor:
     ]
 
     def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
-        self._exported_program = exported_program or None
+        self._exported_program = exported_program
         self.tosa_spec = tosa_spec
 
     def define_node(
diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py
index 5864495fc05..098fbeccce1 100644
--- a/backends/arm/operators/op_rescale.py
+++ b/backends/arm/operators/op_rescale.py
@@ -38,7 +38,6 @@ def define_node(
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])
 
-        # Skip int16 cases for now.
         if input_dtype != map_dtype(torch.int8) and input_zp != 0:
             raise ValueError(
                 f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
             )
@@ -48,7 +47,10 @@
             f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
         )
 
-        scale_width = 32 if output_dtype == torch.int32 else 16
+        # scale32 gives higher accuracy but at a higher HW cost.
+        # For now, always go for scale32.
+        scale_32 = True
+        scale_width = 32 if scale_32 else 16
         multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
             [scale], scale_width
         )
@@ -58,7 +60,7 @@ def define_node(
             output_zp=output_zp,
             multiplier=multiplier,
             shift=shift,
-            scale32=output_dtype == torch.int32,
+            scale32=scale_32,
             double_round=False,
             per_channel=False,
             input_unsigned=False,
diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py
index b411d8b91ba..da7e2e8be95 100644
--- a/backends/arm/operators/op_table.py
+++ b/backends/arm/operators/op_table.py
@@ -30,11 +30,24 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        assert node.name in self._exported_program.state_dict.keys()  # type: ignore[union-attr]
-        assert inputs[0].dtype == output.dtype == ts.DType.INT8
+        if node.name not in self._exported_program.state_dict.keys():  # type: ignore[union-attr]
+            raise RuntimeError(
+                f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
+            )
+        if inputs[0].dtype == ts.DType.INT8 and output.dtype != ts.DType.INT8:
+            raise ValueError(f"Int8 tables need int8 output, got {output.dtype=}.")
+        if inputs[0].dtype == ts.DType.INT16 and output.dtype != ts.DType.INT32:
+            raise ValueError(f"Int16 tables need int32 output, got {output.dtype=}.")
+
+        if inputs[0].dtype not in (ts.DType.INT8, ts.DType.INT16):
+            raise ValueError(
+                f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0].dtype]}."
+            )
+
         table = self._exported_program.state_dict[node.name]  # type: ignore[union-attr]
         table_attr = ts.TosaSerializerAttribute()
         table_attr.TableAttribute(np.array(table))
+
         tosa_graph.addOperator(
             TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
         )
diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py
new file mode 100644
index 00000000000..a1d0622b32f
--- /dev/null
+++ b/backends/arm/test/ops/test_sigmoid_16bit.py
@@ -0,0 +1,173 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+)
+from executorch.backends.xnnpack.test.tester import Quantize
+from torch.ao.quantization.observer import HistogramObserver
+from torch.ao.quantization.quantizer import QuantizationSpec
+
+
+def _get_16_bit_quant_config():
+    int16_spec = QuantizationSpec(
+        dtype=torch.int16,
+        observer_or_fake_quant_ctr=HistogramObserver,
+        qscheme=torch.per_tensor_symmetric,
+    )
+    qconfig = QuantizationConfig(
+        input_activation=int16_spec,
+        output_activation=int16_spec,
+        weight=None,
+        bias=None,
+    )
+    return qconfig
+
+
+def get_16bit_sigmoid_quantizer(tosa_str: str):
+    tosa_spec = common.TosaSpecification.create_from_string(tosa_str)
+    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer.set_global(get_symmetric_quantization_config())
+    quantizer.set_module_type(
+        torch.nn.modules.activation.Sigmoid, _get_16_bit_quant_config()
+    )
+
+    return Quantize(quantizer, get_symmetric_quantization_config())
+
+
+input_t = tuple[torch.Tensor]
+test_data_suite = {
+    "ones": (torch.ones(10, 10, 10),),
+    "rand": (torch.rand(10, 10) - 0.5,),
+    "rand_4d": (torch.rand(1, 1, 5, 10),),
+    "randn_pos": (torch.randn(10) + 10,),
+    "randn_neg": (torch.randn(10) - 10,),
+    "ramp": (torch.arange(-16, 16, 0.02),),
+}
+
+
+class Sigmoid(torch.nn.Module):
+    aten_op = "torch.ops.aten.sigmoid.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default"
+
+    def __init__(self):
+        super().__init__()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid(x)
+
+
+class SigmoidAddSigmoid(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid((self.sigmoid(x) + self.sigmoid(x)))
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_sigmoid_tosa_BI(test_data):
+    pipeline = TosaPipelineBI(Sigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op)
+    pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI"))
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+    xfails={
+        "ramp": "AssertionError: Output 0 does not match reference output. Passes with qtol=2. MLETORCH-787"
+    },
+)
+def test_sigmoid_add_sigmoid_tosa_BI(test_data):
+    pipeline = TosaPipelineBI(
+        SigmoidAddSigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op
+    )
+    pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI"))
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+    xfails={
+        "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
+        "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
+        "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
+        "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
MLBEDSW-9770", + }, +) +@common.XfailIfNoCorstone300 +def test_sigmoid_tosa_u55(test_data): + pipeline = EthosU55PipelineBI( + Sigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, + xfails={ + "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + }, +) +@common.XfailIfNoCorstone300 +def test_sigmoid_add_sigmoid_tosa_u55(test_data): + pipeline = EthosU55PipelineBI( + SigmoidAddSigmoid(), + test_data, + Sigmoid.aten_op, + Sigmoid.exir_op, + run_on_fvp=True, + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_sigmoid_tosa_u85(test_data): + pipeline = EthosU85PipelineBI( + Sigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI")) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, + xfails={ + "ramp": "AssertionError: Output 0 does not match reference output.", + }, +) +@common.XfailIfNoCorstone320 +def test_sigmoid_add_sigmoid_tosa_u85(test_data): + pipeline = EthosU85PipelineBI( + SigmoidAddSigmoid(), + test_data, + Sigmoid.aten_op, + Sigmoid.exir_op, + run_on_fvp=True, + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI")) + pipeline.run() diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py new file mode 100644 index 00000000000..0a7bcec1043 --- /dev/null +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -0,0 +1,200 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+)
+from executorch.backends.xnnpack.test.tester import Quantize
+from torch.ao.quantization.observer import HistogramObserver
+from torch.ao.quantization.quantizer import QuantizationSpec
+
+
+def _get_16_bit_quant_config():
+    int16_spec = QuantizationSpec(
+        dtype=torch.int16,
+        observer_or_fake_quant_ctr=HistogramObserver,
+        qscheme=torch.per_tensor_symmetric,
+    )
+    int32_spec = QuantizationSpec(
+        dtype=torch.int32,
+        observer_or_fake_quant_ctr=HistogramObserver,
+        qscheme=torch.per_tensor_symmetric,
+    )
+    qconfig = QuantizationConfig(
+        input_activation=int16_spec,
+        output_activation=int32_spec,
+        weight=None,
+        bias=None,
+    )
+    return qconfig
+
+
+def _get_32_bit_quant_config():
+    int32_spec = QuantizationSpec(
+        dtype=torch.int32,
+        observer_or_fake_quant_ctr=HistogramObserver,
+        qscheme=torch.per_tensor_symmetric,
+    )
+    qconfig = QuantizationConfig(
+        input_activation=int32_spec,
+        output_activation=int32_spec,
+        weight=None,
+        bias=None,
+    )
+    return qconfig
+
+
+def get_16bit_sigmoid_quantizer(tosa_str: str):
+    tosa_spec = common.TosaSpecification.create_from_string(tosa_str)
+    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer.set_global(_get_32_bit_quant_config())
+    quantizer.set_module_type(
+        torch.nn.modules.activation.Sigmoid, _get_16_bit_quant_config()
+    )
+
+    return Quantize(quantizer, _get_32_bit_quant_config())
+
+
+input_t = tuple[torch.Tensor]
+test_data_suite = {
+    "ones": (torch.ones(10, 10, 10),),
+    "rand": (torch.rand(10, 10) - 0.5,),
+    "rand_4d": (torch.rand(1, 10, 10, 10),),
+    "randn_pos": (torch.randn(10) + 10,),
+    "randn_neg": (torch.randn(10) - 10,),
+    "ramp": (torch.arange(-16, 16, 0.2),),
+}
+
+
+class Sigmoid(torch.nn.Module):
+    aten_op = "torch.ops.aten.sigmoid.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default"
+
+    def __init__(self):
+        super().__init__()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid(x)
+
+
+class SigmoidAddSigmoid(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sigmoid = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid((self.sigmoid(x) + self.sigmoid(x)))
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_sigmoid_tosa_BI(test_data):
+    pipeline = TosaPipelineBI(
+        Sigmoid(),
+        test_data,
+        Sigmoid.aten_op,
+        Sigmoid.exir_op,
+    )
+    pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI"))
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_sigmoid_add_sigmoid_tosa_BI(test_data):
+    pipeline = TosaPipelineBI(
+        SigmoidAddSigmoid(),
+        test_data,
+        Sigmoid.aten_op,
+        Sigmoid.exir_op,
+    )
+    pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI"))
+    pipeline.change_args("run_method_and_compare_outputs", test_data, qtol=1)
+
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+    xfails={
+        "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
+        "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
+        "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
MLBEDSW-9770", + "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + }, +) +@common.XfailIfNoCorstone300 +def test_sigmoid_tosa_u55(test_data): + pipeline = EthosU55PipelineBI( + Sigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) + pipeline.change_args("run_method_and_compare_outputs", test_data, qtol=1) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, + xfails={ + "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", + }, +) +@common.XfailIfNoCorstone300 +def test_sigmoid_add_sigmoid_tosa_u55(test_data): + pipeline = EthosU55PipelineBI( + SigmoidAddSigmoid(), + test_data, + Sigmoid.aten_op, + Sigmoid.exir_op, + run_on_fvp=True, + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) + pipeline.change_args("run_method_and_compare_outputs", test_data, qtol=1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_sigmoid_tosa_u85(test_data): + pipeline = EthosU85PipelineBI( + Sigmoid(), test_data, Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI")) + pipeline.change_args("run_method_and_compare_outputs", test_data, qtol=1) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, + xfails={ + "ramp": "AssertionError: Output 0 does not match reference output.", + }, +) +@common.XfailIfNoCorstone320 +def test_sigmoid_add_sigmoid_tosa_u85(test_data): + pipeline = EthosU85PipelineBI( + SigmoidAddSigmoid(), + test_data, + Sigmoid.aten_op, + Sigmoid.exir_op, + run_on_fvp=True, + ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI")) + pipeline.change_args("run_method_and_compare_outputs", test_data, qtol=1) + pipeline.run()