diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index a6d649fd92e..43929d3b1c8 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -39,6 +42,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -105,6 +109,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) # Specification (0.80) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -157,6 +162,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + # Specification (1.0) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): @@ -224,6 +231,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + # Specification (1.0) states that input and output types # should all be the same if not (inputs[0].dtype == output.dtype): diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 11c32a3ae5f..fc8ecbb960a 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -118,6 +122,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -169,6 +174,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -237,6 +244,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index 6bb9d563ca6..52cfbb18e81 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -9,6 +9,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -31,6 +34,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 5c0fee5cfaf..d9f05c6f9f1 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -9,6 +9,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -31,6 +34,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + input = inputs[0] dim = inputs[1].number diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index efb5b0b72b0..d8be68fbbc1 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore from torch.fx import Node @@ -30,6 +33,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -69,6 +74,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 727fd52dfd5..504de7319a2 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -85,6 +88,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -122,6 +127,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -212,6 +219,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8] if inputs[0].dtype not in supported_dtypes: raise TypeError( @@ -252,6 +261,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4, 6]) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] if inputs[0].dtype not in supported_dtypes: raise TypeError( diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index ebc43ca33f6..8c68bde2006 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -17,7 +17,9 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_specification import TosaSpecification @@ -46,6 +48,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: raise TypeError( f"All IO needs to have the same data type, got: " @@ -128,6 +131,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: raise TypeError( f"All IO needs to have the same data type, got: " diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index bb77ba77940..c7bad9e4429 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [1, 2]) + tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) @@ -68,6 +73,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, [1, 2]) + tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index aedcc643e5d..566121d1bbb 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -15,6 +15,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float: # Attempt to cast to float return float(value) - if len(node.args) != 2 and len(node.args) != 3: - raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}") - min_arg = dtype_min max_arg = dtype_max @@ -87,10 +87,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) min_int8, max_int8 = self._get_min_max_arguments( node, @@ -130,10 +127,7 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) if inputs[0].dtype == ts.DType.INT8: # Call the inherited define_node for handling integers @@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float: # Attempt to cast to float return float(value) - if len(node.args) != 2 and len(node.args) != 3: - raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}") - min_arg = dtype_min max_arg = dtype_max @@ -202,10 +193,7 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments min_int8, max_int8 = self._get_min_max_arguments( @@ -247,10 +235,7 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, [2, 3]) min_fp32, max_fp32 = self._get_min_max_arguments( node, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 75cdc0b0fc4..57c13664e76 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) qargs = input_qparams[0] @@ -98,9 +103,10 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) qargs = input_qparams[0] diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 6f91c181bd2..fd35439d64a 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -17,6 +17,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_specification import TosaSpecification @@ -67,6 +70,7 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input, weight, bias, stride, pad, dilation, _, _, group = inputs + validate_num_inputs(self.target, inputs, 9) # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() @@ -242,6 +246,7 @@ def define_node( from tosa.RoundingMode import RoundingMode # type: ignore input, weight, bias, stride, pad, dilation, _, _, group = inputs + validate_num_inputs(self.target, inputs, 9) # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index 1fee25511ce..43fa26176e5 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -33,10 +36,7 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index a36a1f1b0cd..4cfa6012145 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator EQ but got " @@ -89,6 +94,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator EQ but got " diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e174069ee77..bfce5c26699 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -63,6 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 60cc727d149..b23973a20a9 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input dtype: " diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index c929f5f9c87..9c4425857f8 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GE but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GE but got " diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 53196a0d03c..638dee7ccfc 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GT but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator GT but got " diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index d927c1ba0db..bc7751c90dc 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LE but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LE but got " diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index b08bbcec003..9b4ef4c7b73 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 2e49eda7d98..02ca0d4d263 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LT but got " @@ -88,6 +93,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype: raise TypeError( "All inputs need to have the same data type for operator LT but got " diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 928262aefc5..40f48d3896f 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -41,6 +44,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4]) + input_tensor = inputs[0] kernel_size = inputs[1].special stride = inputs[2].special @@ -109,6 +114,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [3, 4]) + input_tensor = inputs[0] kernel_size = inputs[1].special stride = inputs[2].special diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 983ac5ded6d..5d5c56b90f8 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -17,7 +17,9 @@ NodeVisitor, register_node_visitor, ) - +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape @@ -46,6 +48,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -113,6 +117,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.NanPropagationMode import NanPropagationMode # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index f39e2ce6d61..85c9b4ac3ed 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -16,6 +16,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape @@ -44,6 +47,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " @@ -111,6 +116,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.NanPropagationMode import NanPropagationMode # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype != inputs[1].dtype and inputs[0].dtype != output.dtype: raise TypeError( f"Data type of inputs and output must be the same. Got input 0 dtype: " diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 6c5b94f1a2b..7d84be213b9 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -19,6 +19,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import reshape_for_broadcast @@ -42,6 +45,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if ( inputs[0].dtype != ts.DType.INT8 or inputs[1].dtype != ts.DType.INT8 @@ -122,6 +127,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) @@ -152,6 +159,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if ( inputs[0].dtype != ts.DType.INT8 or inputs[1].dtype != ts.DType.INT8 @@ -218,6 +227,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index f3ea8b00961..b78ee94b774 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -105,6 +108,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) @@ -142,6 +147,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 781fce3c79f..0b9ba6321f7 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -36,6 +39,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -77,6 +82,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 7d1ee951993..d8888ec9d49 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -35,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " @@ -69,10 +70,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 979a10ecff1..1ed42b23b9e 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape @@ -34,6 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + multiples = inputs[1].special attr = ts.TosaSerializerAttribute() @@ -61,6 +66,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + multiples = inputs[1].special if len(multiples) == 0: diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 3c9abe1ba57..52953db24d0 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale @@ -35,6 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 5) + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) @@ -91,6 +96,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore + validate_num_inputs(self.target, inputs, 5) + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 375dd76ba8d..e843f669a58 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00 @@ -32,6 +35,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() round = False if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: @@ -63,6 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() round = False if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions: diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 784c4b4d257..53156e9249a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -35,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " @@ -67,10 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got " diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index a43e9ae798f..2881fc02eb5 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -10,6 +10,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,10 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index ee444c38f37..e082f6cb7a4 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -33,10 +36,8 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index a8d326cfa9b..412e3cca922 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -11,6 +11,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from torch.fx import Node @@ -47,6 +50,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [4, 5]) + # See slice_copy_support.py if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") @@ -99,6 +104,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, [4, 5]) + # See slice_copy_support.py if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 65126f4d4dc..03c930918d7 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -113,6 +118,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (0.80) states that input and output types # should all be the same if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: @@ -167,6 +174,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same assert inputs[0].dtype == inputs[1].dtype == output.dtype @@ -228,6 +237,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + # Specification (1.0) states that input and output types # should all be the same assert inputs[0].dtype == inputs[1].dtype == output.dtype diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index b898eb6cb67..f232136fd9b 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -40,6 +43,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) dim_list = [dim % len(input_shape) for dim in dim_list] @@ -98,6 +103,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name @@ -151,6 +158,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) dim_list = [dim % len(input_shape) for dim in dim_list] @@ -210,6 +219,8 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index 454aebecd5e..350403f19bc 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." @@ -71,6 +76,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + if node.name not in self._exported_program.state_dict.keys(): # type: ignore[union-attr] raise RuntimeError( f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}." diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 01af36c4d37..02727d0fabe 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -10,8 +10,12 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification + from torch.fx import Node @@ -34,10 +38,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " @@ -66,10 +68,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) + validate_num_inputs(self.target, inputs, 1) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: raise ValueError( f"Input and output for {self.target} need to be FP32, got input_dtype: " diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index 210bfd2f61f..5dde6828f72 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) @@ -66,4 +71,6 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py index 740576f2736..d68bee88a64 100644 --- a/backends/arm/operators/op_to_dim_order_copy.py +++ b/backends/arm/operators/op_to_dim_order_copy.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -39,6 +42,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) @@ -66,4 +71,6 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 1) + tosa_graph.addOperator(ts.TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index ac98979c234..8b0754fa079 100644 --- a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -13,6 +13,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -37,6 +40,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() @@ -67,6 +72,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 2) + output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py index 1c0fbc11d24..88149a7be91 100644 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80 from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape @@ -36,6 +39,8 @@ def define_node( import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore + validate_num_inputs(self.target, inputs, 4) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") @@ -123,6 +128,8 @@ def define_node( from tosa.ResizeMode import ResizeMode # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore + validate_num_inputs(self.target, inputs, 4) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index c08896c2cdc..da40859de74 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape @@ -36,6 +39,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + if inputs[0].shape is None or output.shape is None: raise ValueError("Only static shapes are supported") @@ -92,6 +97,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + assert ( inputs[0].shape is not None and output.shape is not None ), "Only static shapes are supported" diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index e8dedb65315..22a8146ecbd 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_utils import tosa_shape @@ -34,6 +37,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + attr = ts.TosaSerializerAttribute() new_shape = tosa_shape(inputs[1].special, output.dim_order) attr.ReshapeAttribute(new_shape) @@ -61,6 +66,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 2) + tosa_graph = cast(ts.TosaSerializer, tosa_graph) if len(output.shape) != 0: diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index b58fda1c399..d34f4134def 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -9,6 +9,10 @@ NodeVisitor, register_node_visitor, ) + +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from torch.fx import Node @@ -34,8 +38,7 @@ def _add_node_to_tosa_graph( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + validate_num_inputs(self.target, inputs, 3) if inputs[0].dtype is not ts.DType.BOOL: raise ValueError("Input 0 needs to have dtype BOOL") @@ -66,6 +69,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + bi_supported_dtypes = [ ts.DType.INT8, ts.DType.INT16, @@ -94,6 +99,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + validate_num_inputs(self.target, inputs, 3) + mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, @@ -125,8 +132,7 @@ def _add_node_to_tosa_graph( ) -> None: import serializer.tosa_serializer as ts - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + validate_num_inputs(self.target, inputs, 3) if inputs[0].dtype is not ts.DType.BOOL: raise ValueError("Input 0 needs to have dtype BOOL") @@ -157,6 +163,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + bi_supported_dtypes = [ ts.DType.INT8, ts.DType.INT16, @@ -185,6 +193,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 3) + mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py new file mode 100644 index 00000000000..824695b4643 --- /dev/null +++ b/backends/arm/operators/operator_validation_utils.py @@ -0,0 +1,53 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + + +def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]): + """ + Validates the number of inputs provided to an operation against expected values. + + This function checks whether the length of the input list matches the expected + number(s) of inputs. + + Parameters: + ----------- + op_name : str + The name of the operation for which the inputs are being validated. + Used in the error message to provide context. + + inputs : List[TosaArg] + A list of inputs to be validated, where each input is assumed to be an + instance of `TosaArg`. + + expected : int or List[int] + The expected number of inputs. Can be either an integer or a list of integers. + + Raises: + ------- + ValueError + If the number of inputs does not match the expected value(s), a `ValueError` is + raised with a message indicating the operation name and the mismatch in expected + versus provided number of inputs. + + Example: + -------- + # Example usage: + from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + ) + + validate_num_inputs(self.target, inputs, [3, 4]) + + """ + if isinstance(expected, int): + expected = [expected] + if len(inputs) not in expected: + expected_str = ", ".join(map(str, expected)) + raise ValueError( + f"{op_name}: Expected number of input(s) to be " + f"[{expected_str}], got {len(inputs)}" + ) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 425007bab3c..0a2f4419dfb 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -33,6 +36,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." @@ -62,6 +67,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 2) + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 0c41e13d445..cd5fa9956a3 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -14,6 +14,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -37,6 +40,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + # Simply add an identityOp tosa_graph.addOperator( ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] @@ -69,6 +74,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts + validate_num_inputs(self.target, inputs, 1) + # Simply add an identityOp tosa_graph.addOperator( ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name] diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 3bb2be16585..b7ba2df4277 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -12,6 +12,9 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) from executorch.backends.arm.tosa_mapping import TosaArg @@ -38,6 +41,8 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -76,6 +81,8 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + validate_num_inputs(self.target, inputs, 1) + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype."