diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 0b474a0b077..91499a5a892 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -16,7 +16,10 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -60,14 +63,12 @@ def define_node( ts.DType.FP32, } + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output]) + if inputs[0].dtype not in supported_dtypes: raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") - if inputs[0].dtype != output.dtype: - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) input_zp, output_zp = get_negate_zero_points( node, inputs[0].dtype == ts.DType.INT8 ) @@ -109,14 +110,12 @@ def define_node( ts.DType.FP32, } + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output]) + if inputs[0].dtype not in supported_dtypes: raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") - if inputs[0].dtype != output.dtype: - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) input_zp, output_zp = get_negate_zero_points( node, inputs[0].dtype == ts.DType.INT8 )