diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index dedd8307ed4..ee8eb08592a 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -14,7 +14,7 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import tosa_shape from torch._export.utils import ( get_buffer, get_lifted_tensor_constant, @@ -33,7 +33,10 @@ def process_call_function( tosa_spec: TosaSpecification, ): # Unpack arguments and convert - inputs = getNodeArgs(node, tosa_spec) + try: + inputs = [TosaArg(arg, tosa_spec) for arg in node.args] + except ValueError as e: + raise ValueError(f"Failed processing args to op:\n{node}") from e # Convert output (this node itself) try: diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 7d544e46bfc..fec8f4337a2 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -15,7 +15,7 @@ import torch -from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg +from executorch.backends.arm.tosa_mapping import extract_tensor_meta from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops @@ -26,13 +26,6 @@ logger = logging.getLogger(__name__) -def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]: - try: - return [TosaArg(arg, tosa_spec) for arg in node.args] - except ValueError as e: - raise ValueError(f"Failed processing args to op:\n{node}") from e - - def are_fake_tensors_broadcastable( fake_tensors: list[FakeTensor], ) -> tuple[bool, list[int]]: