diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 73a6713633a..727fd52dfd5 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -85,8 +85,12 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 + supported_dtypes = [ts.DType.INT8] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) accumulator_type = ts.DType.INT32 @@ -118,9 +122,12 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: super().define_node(node, tosa_graph, inputs, output) @@ -205,8 +212,12 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 + supported_dtypes = [ts.DType.INT8] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) accumulator_type = ts.DType.INT32 @@ -241,9 +252,12 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if inputs[0].dtype not in supported_dtypes: + raise TypeError( + f"IO data type needs to be one of {supported_dtypes}, got " + f'"{inputs[0].dtype}"' + ) if inputs[0].dtype == ts.DType.INT8: super().define_node(node, tosa_graph, inputs, output)