From f8f6005268f0de166f6a4474eda1081527de76ea Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 7 May 2025 13:16:58 +0200 Subject: [PATCH] Arm backend: Add validation steps to op_neg The validation steps replace the raises that previously verified the same thing. This reduces duplicated code. Signed-off-by: Sebastian Larsson Change-Id: I6bafc0ef332cd1b7dbf7474bc524637d734ec4fc --- backends/arm/operators/op_neg.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 0b474a0b077..91499a5a892 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -16,7 +16,10 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -60,14 +63,12 @@ def define_node( ts.DType.FP32, } + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output]) + if inputs[0].dtype not in supported_dtypes: raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") - if inputs[0].dtype != output.dtype: - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) input_zp, output_zp = get_negate_zero_points( node, inputs[0].dtype == ts.DType.INT8 ) @@ -109,14 +110,12 @@ def define_node( ts.DType.FP32, } + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output]) + if inputs[0].dtype not in supported_dtypes: raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") - if inputs[0].dtype != output.dtype: - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) input_zp, output_zp = get_negate_zero_points( node, inputs[0].dtype == ts.DType.INT8 )