diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 61f01cb7099..4c09ed91f16 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -50,7 +50,10 @@ def define_node( validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.INT8, output.tosa_spec + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT32], + output.tosa_spec, ) dim_order = ( @@ -58,30 +61,39 @@ def define_node( if len(inputs[0].shape) > len(inputs[1].shape) else inputs[1].dim_order ) - input_A = inputs[0] - input_B = inputs[1] - input_qparams = get_input_qparams(node) - input_A_qargs = input_qparams[0] - input_B_qargs = input_qparams[1] - input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) - input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) - - # Rescale inputs to INT32 with zp=0 - input_A_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_A, - input_A_qargs.get_zp_per_tensor(), - 1.0, - ) - input_B_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_B, - input_B_qargs.get_zp_per_tensor(), - 1.0, - ) - - output_shape = tutils.tosa_shape(output.shape, output.dim_order) - mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + if inputs[0].dtype == ts.DType.INT8: + input_A = inputs[0] + input_B = inputs[1] + input_qparams = get_input_qparams(node) + input_A_qargs = input_qparams[0] + input_B_qargs = input_qparams[1] + input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) + input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) + + # Rescale inputs to INT32 with zp=0 + input_A_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_A, + input_A_qargs.get_zp_per_tensor(), + 1.0, + ) + input_B_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_B, + input_B_qargs.get_zp_per_tensor(), + 1.0, + ) + 
else: + # inputs[0].dtype == ts.DType.INT32 + # Non-quantized input, natively supported by TOSA.MUL + input_A_rescaled, input_B_rescaled = inputs[0], inputs[1] + + if output.dtype == ts.DType.INT8: + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + mul_output = output input1, input2 = tutils.reshape_for_broadcast( tosa_graph, @@ -101,10 +113,16 @@ def define_node( [mul_output.name], attr, ) - output_scale = ( - input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor() - ) - tqutils.insert_rescale_op_to_int8(tosa_graph, mul_output, output_scale, node) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + output_scale = ( + input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] + * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] + ) + tqutils.insert_rescale_op_to_int8( + tosa_graph, mul_output, output_scale, node + ) @register_node_visitor @@ -161,35 +179,47 @@ def define_node( validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.INT8, output.tosa_spec - ) - - input_A = inputs[0] - input_B = inputs[1] - input_qparams = get_input_qparams(node) - input_A_qargs = input_qparams[0] - input_B_qargs = input_qparams[1] - input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) - input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) - - # Rescale inputs to INT32 with zp=0 - input_A_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_A, - input_A_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, - ) - input_B_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_B, - input_B_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT32],
+ output.tosa_spec, ) - output_shape = tutils.tosa_shape(output.shape, output.dim_order) - mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + if inputs[0].dtype == ts.DType.INT8: + input_A = inputs[0] + input_B = inputs[1] + input_qparams = get_input_qparams(node) + input_A_qargs = input_qparams[0] + input_B_qargs = input_qparams[1] + input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) + input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) + + # Rescale inputs to INT32 with zp=0 + input_A_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_A, + input_A_qargs.get_zp_per_tensor(), + 1.0, + tosa_spec=self.tosa_spec, + ) + input_B_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_B, + input_B_qargs.get_zp_per_tensor(), + 1.0, + tosa_spec=self.tosa_spec, + ) + else: + # inputs[0].dtype == ts.DType.INT32 + # Non-quantized input, natively supported by TOSA.MUL + input_A_rescaled, input_B_rescaled = inputs[0], inputs[1] + + if output.dtype == ts.DType.INT8: + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + mul_output = output # Do the INT32 Mul tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") @@ -198,12 +228,16 @@ def define_node( [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], [mul_output.name], ) - output_scale = ( - input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor() - ) - tqutils.insert_rescale_op_to_int8( - tosa_graph, mul_output, output_scale, node, self.tosa_spec - ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + output_scale = ( + input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] + * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] + ) + tqutils.insert_rescale_op_to_int8( + tosa_graph, mul_output, output_scale, node,
self.tosa_spec + ) @register_node_visitor diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index a4c0dd4a0f8..b061e57287a 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -79,6 +79,23 @@ } +test_data_suite_int32 = { + # (test_name, input, other,) See torch.mul() for info + "op_mul_rank4_randn_int32": lambda: ( + torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32), + torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32), + ), + "op_mul_rank4_randn_mutltiple_broadcasts_int32": lambda: ( + torch.randint(0, 10, (1, 4, 4, 1), dtype=torch.int32), + torch.randint(0, 10, (1, 1, 4, 4), dtype=torch.int32), + ), + "op_mul_rank4_randn_broadcast_int32": lambda: ( + torch.randint(0, 10, (1, 10, 25, 20), dtype=torch.int32), + torch.randint(0, 10, (1, 25, 20), dtype=torch.int32), + ), +} + + class Mul(torch.nn.Module): def forward( @@ -111,6 +128,17 @@ def test_mul_tensor_tosa_MI_diff_input_ranks(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite_int32) +def test_mul_tensor_tosa_MI_int32(test_data: torch.Tensor): + pipeline = TosaPipelineMI[input_t1]( + Mul(), + test_data(), + aten_op, + exir_op=[], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite_2) def test_mul_tensor_tosa_BI_diff_input_ranks(test_data: torch.Tensor): pipeline = TosaPipelineBI[input_t1]( @@ -133,6 +161,18 @@ def test_mul_tensor_tosa_BI(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite_int32) +def test_mul_tensor_tosa_BI_int32(test_data: torch.Tensor): + pipeline = TosaPipelineBI[input_t1]( + Mul(), + test_data(), + aten_op, + exir_op=[], + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_mul_tensor_u55_BI(test_data: torch.Tensor): @@ -157,3 +197,47 @@ def test_mul_tensor_u85_BI(test_data: torch.Tensor): run_on_fvp=True, ) 
pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite_int32, + xfails={ + # TODO: MLETORCH-1132 Investigate why tests with inputs that require broadcasting fail on u55/u85 + "op_mul_rank4_randn_mutltiple_broadcasts_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int", + "op_mul_rank4_randn_broadcast_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int", + }, +) +@common.XfailIfNoCorstone300 +def test_mul_tensor_u55_BI_int32(test_data: torch.Tensor): + pipeline = EthosU55PipelineBI[input_t1]( + Mul(), + test_data(), + aten_op, + exir_ops=[], + run_on_fvp=True, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite_int32, + xfails={ + # TODO: MLETORCH-1132 Investigate why tests with inputs that require broadcasting fail on u55/u85 + "op_mul_rank4_randn_mutltiple_broadcasts_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int", + "op_mul_rank4_randn_broadcast_int32": "RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Int", + }, +) +@common.XfailIfNoCorstone320 +def test_mul_tensor_u85_BI_int32(test_data: torch.Tensor): + pipeline = EthosU85PipelineBI[input_t1]( + Mul(), + test_data(), + aten_op, + exir_ops=[], + run_on_fvp=True, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run()