diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 92cc2b37479..6fb89d44210 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -111,13 +111,22 @@ def define_node( output: TosaArg, ) -> None: import serializer.tosa_serializer as ts + from executorch.backends.arm.tosa.specification import Tosa_1_00 validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [inputs[0], output], ts) + + # Build list of valid dtypes based on TOSA spec + valid_dtypes = [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + if isinstance( + output.tosa_spec, Tosa_1_00 + ) and output.tosa_spec.support_extension("int16"): + valid_dtypes.append(ts.DType.INT16) + validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + valid_dtypes, output.tosa_spec, )