|
16 | 16 | NodeVisitor, |
17 | 17 | register_node_visitor, |
18 | 18 | ) |
19 | | - |
| 19 | +from executorch.backends.arm.operators.operator_validation_utils import ( |
| 20 | + validate_num_inputs, |
| 21 | + validate_same_dtype, |
| 22 | +) |
20 | 23 | from executorch.backends.arm.tosa_mapping import TosaArg |
21 | 24 |
|
22 | 25 |
|
@@ -60,14 +63,12 @@ def define_node( |
60 | 63 | ts.DType.FP32, |
61 | 64 | } |
62 | 65 |
|
| 66 | + validate_num_inputs(self.target, inputs, 1) |
| 67 | + validate_same_dtype(self.target, [*inputs, output]) |
| 68 | + |
63 | 69 | if inputs[0].dtype not in supported_dtypes: |
64 | 70 | raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") |
65 | 71 |
|
66 | | - if inputs[0].dtype != output.dtype: |
67 | | - raise ValueError( |
68 | | - "All inputs and output need same dtype." |
69 | | - f"Got {inputs[0].dtype=}, {output.dtype=}" |
70 | | - ) |
71 | 72 | input_zp, output_zp = get_negate_zero_points( |
72 | 73 | node, inputs[0].dtype == ts.DType.INT8 |
73 | 74 | ) |
@@ -109,14 +110,12 @@ def define_node( |
109 | 110 | ts.DType.FP32, |
110 | 111 | } |
111 | 112 |
|
| 113 | + validate_num_inputs(self.target, inputs, 1) |
| 114 | + validate_same_dtype(self.target, [*inputs, output]) |
| 115 | + |
112 | 116 | if inputs[0].dtype not in supported_dtypes: |
113 | 117 | raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") |
114 | 118 |
|
115 | | - if inputs[0].dtype != output.dtype: |
116 | | - raise ValueError( |
117 | | - "All inputs and output need same dtype." |
118 | | - f"Got {inputs[0].dtype=}, {output.dtype=}" |
119 | | - ) |
120 | 119 | input_zp, output_zp = get_negate_zero_points( |
121 | 120 | node, inputs[0].dtype == ts.DType.INT8 |
122 | 121 | ) |
|
0 commit comments