diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index 5d82810f0d7..c1509b4feae 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -164,7 +164,7 @@ def define_node( scale_back = 1.0 if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # type: ignore[possibly-undefined] else: # input[0].dtype == ts.DType.INT32 @@ -192,7 +192,7 @@ def define_node( # Scale output back to 8 bit # pyre-ignore tqutils.insert_rescale_op_to_int8( - tosa_graph, abs_output, scale_back, node, self.tosa_specs + tosa_graph, abs_output, scale_back, node, self.tosa_spec ) # type: ignore[possibly-undefined] diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index b8e3d1561ca..9b981f23710 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -174,7 +174,7 @@ def define_node( scale_back = 1.0 if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) else: # input[0].dtype == ts.DType.INT32 @@ -202,7 +202,7 @@ def define_node( # Scale output back to 8 bit # pyre-ignore tqutils.insert_rescale_op_to_int8( - tosa_graph, add_output, scale_back, node, self.tosa_specs + tosa_graph, add_output, scale_back, node, self.tosa_spec ) # type: ignore[possibly-undefined] diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index c0839120821..a4318904c8e 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -98,7 +98,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Rescale inputs to 32 bit rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # Update IO diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 7a8f793e24b..5ff2aefa4db 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -97,7 +97,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Rescale inputs to 32 bit rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # Update IO diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index b640b9bc31d..230e42ea0ce 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -97,7 +97,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Rescale inputs to 32 bit rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # Update IO diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index a458ef126ee..3960c768ce3 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -97,7 +97,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Rescale inputs to 32 bit rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # Update IO diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 76b9a281c76..e0717f75246 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -97,7 +97,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Rescale inputs to 32 bit rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) # Update IO diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index ed7afa4bfd8..99da0026a7f 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -129,7 +129,7 @@ def define_node( ) operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) output.shape = tosa_shape(output.shape, output.dim_order) @@ -155,5 +155,5 @@ def define_node( if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 tqutils.insert_rescale_op_to_int8( - tosa_graph, max_output, scale_back, node, self.tosa_specs + tosa_graph, max_output, scale_back, node, self.tosa_spec ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index c0169e75910..82f3ea945a9 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -128,7 +128,7 @@ def define_node( ) operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) output.shape = tosa_shape(output.shape, output.dim_order) @@ -154,5 +154,5 @@ def define_node( if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 tqutils.insert_rescale_op_to_int8( - tosa_graph, min_output, scale_back, node, self.tosa_specs + tosa_graph, min_output, scale_back, node, self.tosa_spec ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index c4c9c135e6e..789f9222ef7 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -189,14 +189,14 @@ def define_node( input_A, input_A_qargs.zp, [1.0], - tosa_spec=self.tosa_specs, + tosa_spec=self.tosa_spec, ) input_B_rescaled = tqutils.build_rescale_to_int32( tosa_graph, input_B, input_B_qargs.zp, [1.0], - tosa_spec=self.tosa_specs, + tosa_spec=self.tosa_spec, ) output_shape = tutils.tosa_shape(output.shape, output.dim_order) @@ -211,7 +211,7 @@ def define_node( ) output_scale = input_A_qargs.scale * input_B_qargs.scale tqutils.insert_rescale_op_to_int8( - tosa_graph, mul_output, output_scale, node, self.tosa_specs + tosa_graph, mul_output, output_scale, node, self.tosa_spec ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index b711b2f5056..cc3a5591a4c 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -168,7 +168,7 @@ def define_node( scale_back = 1.0 if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_specs + tosa_graph, inputs, node, self.tosa_spec ) else: # input[0].dtype == ts.DType.INT32 @@ -197,7 +197,7 @@ def define_node( # Scale output back to 8 bit # pyre-ignore tqutils.insert_rescale_op_to_int8( - tosa_graph, sub_output, scale_back, node, self.tosa_specs + tosa_graph, sub_output, scale_back, node, self.tosa_spec ) # type: ignore[possibly-undefined] diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 4eb08569005..dd81a0ef077 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -159,13 +159,11 @@ def define_node( # Rescale input to 32 bit rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( - tosa_graph, - [tensor], - node, + tosa_graph, [tensor], node, self.tosa_spec ) attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(tensor.dim_order.index(dim)) + attr.ReduceSumAttribute(tensor.dim_order.index(dim)) intermediate = tosa_graph.addIntermediate( tutils.tosa_shape(output_shape, tensor.dim_order), @@ -179,7 +177,9 @@ def define_node( attr, ) - tqutils.insert_rescale_op_to_int8(tosa_graph, intermediate, scale, node) + tqutils.insert_rescale_op_to_int8( + tosa_graph, intermediate, scale, node, self.tosa_spec + ) @register_node_visitor @@ -212,7 +212,7 @@ def define_node( output_shape[dim] = 1 # Output shape is input shape with dim reduced attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(tensor.dim_order.index(dim)) + attr.ReduceSumAttribute(tensor.dim_order.index(dim)) tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_SUM, diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 96e9ab4e34a..10dc810da6b 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -32,7 +32,7 @@ def insert_rescale_ops_to_int32( tosa_graph: Any, inputs: list[TosaArg], node: Node, - tosa_spec=tosa_specification.Tosa_0_80, + tosa_spec=None, ) -> tuple[list[Any], float]: """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. The scales are adjusted using the smallest scale of all 'nodes'. @@ -79,7 +79,7 @@ def insert_rescale_op_to_int8( last_tensor: TosaArg, scale: float, node: Node, - tosa_spec=tosa_specification.Tosa_0_80, + tosa_spec=None, ) -> None: """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. Parameters: @@ -323,10 +323,11 @@ def build_rescale_to_int32( is_scale32: bool = True, is_double_round: bool = False, per_channel: bool = False, - tosa_spec=tosa_specification.Tosa_0_80, + tosa_spec=None, ) -> Any: input_A_rescaled_to_int32 = None - if tosa_spec == tosa_specification.Tosa_0_80: + if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80): + # default to TOSA v0.80 until we switch to v1.0 import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input_A_rescaled_to_int32 = tosa_fb.addIntermediate( @@ -343,7 +344,7 @@ def build_rescale_to_int32( output_zp=0, ) # type: ignore[call-arg] - elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00): + elif isinstance(tosa_spec, tosa_specification.Tosa_1_00): # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale import serializer.tosa_serializer as ts # type: ignore @@ -375,9 +376,10 @@ def build_rescale_from_int32( is_scale32: bool = True, is_double_round: bool = False, per_channel: bool = False, - tosa_spec=tosa_specification.Tosa_0_80, + tosa_spec=None, ) -> None: - if tosa_spec == tosa_specification.Tosa_0_80: + if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80): + # default to TOSA v0.80 until we switch to v1.0 import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore build_rescale_v0_80( @@ -390,7 +392,7 @@ def build_rescale_from_int32( output_zp=output_zp, ) # type: ignore[call-arg] - elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00): + elif isinstance(tosa_spec, tosa_specification.Tosa_1_00): import serializer.tosa_serializer as ts # type: ignore # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs @@ -420,7 +422,7 @@ def build_rescale_conv_output( weight_scale: list[float], output_scale: list[float], output_zp: int, - tosa_spec=tosa_specification.Tosa_0_80, + tosa_spec=None, ): # TODO add check to verify if this is a Per-channel quantization. post_conv2d_scale = [ @@ -428,7 +430,8 @@ def build_rescale_conv_output( ] # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. - if tosa_spec == tosa_specification.Tosa_0_80: + if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80): + # default to TOSA v0.80 until we switch to v1.0 build_rescale_v0_80( tosa_fb=tosa_fb, scale=post_conv2d_scale,