diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py
index cc3a5591a4c..c65a2a3f43c 100644
--- a/backends/arm/operators/op_sub.py
+++ b/backends/arm/operators/op_sub.py
@@ -163,7 +163,11 @@ def define_node(
         validate_same_dtype(self.target, [*inputs, output])
 
         # Handle int8 (quantized) and int32
-        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+        supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
+        if inputs[0].dtype not in supported_dtypes:
+            raise TypeError(
+                f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
+            )
 
         scale_back = 1.0
         if inputs[0].dtype == ts.DType.INT8:
@@ -228,8 +232,15 @@ def define_node(
             super().define_node(node, tosa_graph, inputs, output)
         else:
             # FP32 Sub lowering
-            assert inputs[0].dtype == ts.DType.FP32
-            assert output.dtype == ts.DType.FP32
+            if (
+                inputs[0].dtype != ts.DType.FP32
+                or inputs[1].dtype != ts.DType.FP32
+                or output.dtype != ts.DType.FP32
+            ):
+                raise TypeError(
+                    f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
+                    f"input 2: {inputs[1].dtype} and output: {output.dtype}"
+                )
 
             # MI lowering
             tosa_graph.addOperator(
diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py
index 3d3c47b7e84..4d8c0ff9320 100644
--- a/backends/arm/operators/op_upsample_bilinear2d.py
+++ b/backends/arm/operators/op_upsample_bilinear2d.py
@@ -153,9 +153,12 @@ def define_node(
         def in_int16_range(x):
             return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
 
-        assert in_int16_range(scale_n_yx)
-        assert in_int16_range(scale_d_yx)
-        assert in_int16_range(border_yx)
+        if not in_int16_range(scale_n_yx):
+            raise ValueError("scale_n_yx is out of the int16 range")
+        if not in_int16_range(scale_d_yx):
+            raise ValueError("scale_d_yx is out of the int16 range")
+        if not in_int16_range(border_yx):
+            raise ValueError("border_yx is out of the int16 range")
 
         scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py
index d5f7b951e40..e9c72145555 100644
--- a/backends/arm/operators/op_upsample_nearest2d.py
+++ b/backends/arm/operators/op_upsample_nearest2d.py
@@ -102,9 +102,8 @@ def define_node(
         validate_num_inputs(self.target, inputs, 3)
         validate_same_dtype(self.target, [inputs[0], output])
 
-        assert (
-            inputs[0].shape is not None and output.shape is not None
-        ), "Only static shapes are supported"
+        if inputs[0].shape is None or output.shape is None:
+            raise ValueError("Only static shapes are supported")
 
         # tosa_shape output is NHWC, take HW
         input_size_yx = torch.tensor(
@@ -121,9 +120,12 @@ def define_node(
         def in_int16_range(x):
            return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
 
-        assert in_int16_range(scale_n_yx)
-        assert in_int16_range(scale_d_yx)
-        assert in_int16_range(border_yx)
+        if not in_int16_range(scale_n_yx):
+            raise ValueError("scale_n_yx is out of the int16 range")
+        if not in_int16_range(scale_d_yx):
+            raise ValueError("scale_d_yx is out of the int16 range")
+        if not in_int16_range(border_yx):
+            raise ValueError("border_yx is out of the int16 range")
 
         scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
         scales_tensor = tosa_graph.addConst(
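
All three hunks apply the same pattern: validate a value against an allow-list (or range) and raise a typed exception with a descriptive message, instead of relying on `assert`, which is stripped when Python runs with `-O`. Below is a minimal, hypothetical sketch of that pattern for the dtype check; `validate_supported_dtype` and its parameters are illustrative names only, not part of this PR or the ExecuTorch ARM backend API.

```python
# Hypothetical helper illustrating the assert-to-exception pattern in this diff.
# The function name and parameters are illustrative, not an existing API.
def validate_supported_dtype(actual, supported, target="<op>"):
    """Raise TypeError if `actual` is not one of the `supported` dtypes."""
    if actual not in supported:
        raise TypeError(
            f'{target}: IO data type needs to be one of {supported}, got "{actual}"'
        )


# Example usage mirroring the op_sub.py change (ts.DType values stand in for
# the TOSA serializer enum used in the real code):
# validate_supported_dtype(inputs[0].dtype, [ts.DType.INT8, ts.DType.INT32], self.target)
```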