diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 1f701f29b1e..ef9ed31c88d 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -40,9 +40,19 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) + # Handle int8 (quantized) and int32 - assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + supported_dtypes = [ts.DType.INT8, ts.DType.INT32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( @@ -97,15 +107,27 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output) else: # FP32 Sub lowering - assert inputs[0].dtype == ts.DType.FP32 - assert output.dtype == ts.DType.FP32 + if ( + inputs[0].dtype != ts.DType.FP32 + or inputs[1].dtype != ts.DType.FP32 + or output.dtype != ts.DType.FP32 + ): + raise TypeError( + f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, " + f"input 2: {inputs[1].dtype} and output: {output.dtype}" + ) # MI lowering tosa_graph.addOperator(