diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 47c3c2d5e5e..06207611e09 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -18,18 +18,17 @@ import serializer.tosa_serializer as ts from executorch.backends.arm.arm_vela import vela_compile from executorch.backends.arm.operators.node_visitor import get_node_visitors -from executorch.backends.arm.operators.op_output import process_output -from executorch.backends.arm.operators.op_placeholder import process_placeholder from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm._passes.arm_pass_manager import ( ArmPassManager, ) # usort: skip -from executorch.backends.arm.tosa_utils import ( - dbg_fail, - dbg_tosa_dump, +from executorch.backends.arm.process_node import ( process_call_function, + process_output, + process_placeholder, ) +from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram diff --git a/backends/arm/operators/op_output.py b/backends/arm/operators/op_output.py deleted file mode 100644 index 1b053b18edc..00000000000 --- a/backends/arm/operators/op_output.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import cast - -import serializer.tosa_serializer as ts -import torch - - -def process_output( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, -): - for output in cast(tuple[torch.fx.Node, ...], node.args[0]): - tosa_graph.addOutputTensor( - tosa_graph.currRegion.currBasicBlock.tensors[output.name] - ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/process_node.py similarity index 78% rename from backends/arm/operators/op_placeholder.py rename to backends/arm/process_node.py index d466a13e385..9a3874c37e9 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/process_node.py @@ -1,14 +1,16 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -# pyre-unsafe +# +from typing import cast, Dict import numpy as np import serializer.tosa_serializer as ts +import torch import torch.fx -from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.operators.node_visitor import NodeVisitor +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_upstream, get_quantized_node_output_dtype, @@ -16,13 +18,51 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import ( + getNodeArgs, is_bias_node_for_quantized_conv, - map_dtype, tosa_shape, ) from torch.export.exported_program import ExportedProgram +def process_call_function( + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + node_visitors: Dict[str, NodeVisitor], + tosa_spec: TosaSpecification, +): + # Unpack arguments and convert + inputs = getNodeArgs(node) + + # Convert output (this node itself) + output = TosaArg(node) + + is_quant_node = is_node_quantized(node) + if is_quant_node: + output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + else: + output_dtype = output.dtype + tosa_graph.currRegion.currBasicBlock.addTensor( + output.name, + tosa_shape(output.shape, output.dim_order), + output_dtype, + ) + + # Visiting each Node + # pyre-ignore[16]: Undefined attribute. + if node.target.__name__ in node_visitors: + # pyre-ignore[16]: Undefined attribute. + node_visitors[node.target.__name__].define_node( + node, + tosa_graph, + inputs, + output, + is_quant_node, + ) + else: + raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") + + def process_inputs( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, @@ -176,3 +216,13 @@ def process_placeholder( ) else: raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") + + +def process_output( + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, +): + for output in cast(tuple[torch.fx.Node, ...], node.args[0]): + tosa_graph.addOutputTensor( + tosa_graph.currRegion.currBasicBlock.tensors[output.name] + ) diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 35ee6ef6b3b..1ae319e0cd7 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,22 +7,18 @@ import logging import os -from typing import Any, cast, Dict +from typing import Any, cast import numpy as np import serializer.tosa_serializer as ts import torch -from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_downstream, get_quant_arg_upstream, - get_quantized_node_output_dtype, - is_node_quantized, q_op, ) -from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -233,44 +229,6 @@ def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order]) -def process_call_function( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - node_visitors: Dict[str, NodeVisitor], - tosa_spec: TosaSpecification, -): - # Unpack arguments and convert - inputs = getNodeArgs(node) - - # Convert output (this node itself) - output = TosaArg(node) - - is_quant_node = is_node_quantized(node) - if is_quant_node: - output_dtype = map_dtype(get_quantized_node_output_dtype(node)) - else: - output_dtype = output.dtype - tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, - (tosa_shape(output.shape, output.dim_order)), - output_dtype, - ) - - # Visiting each Node - # pyre-ignore[16]: Undefined attribute. - if node.target.__name__ in node_visitors: - # pyre-ignore[16]: Undefined attribute. - node_visitors[node.target.__name__].define_node( - node, - tosa_graph, - inputs, - output, - is_quant_node, - ) - else: - raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") - - def expand_dims( tosa_graph: ts.TosaSerializer, input_node: TosaArg,