From 379570638d47d50baef7b7123dffa97a410f9a3b Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Mon, 7 Apr 2025 09:43:30 +0200 Subject: [PATCH] Arm backend: Convert assert to raise TypeError in op_sub Asserts are converted to proper raises to ensure graph integrity. Change-Id: I22b70b97b26f58b8e6bc351ab271a1b2eb1f4bba Signed-off-by: Sebastian Larsson --- backends/arm/operators/op_sub.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 6cd422095ab..06fc9a4ef72 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -41,9 +41,19 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) + # Handle int8 (quantized) and int32 - assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + supported_dtypes = [ts.DType.INT8, ts.DType.INT32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( @@ -98,15 +108,27 @@ def define_node( ) -> None: # Specification (0.80) states that input and output types # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got input 1: " + f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: " + f"{output.dtype}" + ) if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output) else: # FP32 Sub lowering - assert inputs[0].dtype == ts.DType.FP32 - assert output.dtype == ts.DType.FP32 + if ( + inputs[0].dtype != ts.DType.FP32 + or inputs[1].dtype != ts.DType.FP32 + or output.dtype != ts.DType.FP32 + ): + raise TypeError( + f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, " + f"input 2: {inputs[1].dtype} and output: {output.dtype}" + ) # MI lowering tosa_graph.addOperator(