Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@
import serializer.tosa_serializer as ts
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm._passes.arm_pass_manager import (
ArmPassManager,
) # usort: skip
from executorch.backends.arm.tosa_utils import (
dbg_fail,
dbg_tosa_dump,
from executorch.backends.arm.process_node import (
process_call_function,
process_output,
process_placeholder,
)
from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
Expand Down
21 changes: 0 additions & 21 deletions backends/arm/operators/op_output.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,28 +1,68 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
#
from typing import cast, Dict

import numpy as np
import serializer.tosa_serializer as ts
import torch
import torch.fx
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_upstream,
get_quantized_node_output_dtype,
is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
getNodeArgs,
is_bias_node_for_quantized_conv,
map_dtype,
tosa_shape,
)
from torch.export.exported_program import ExportedProgram


def process_call_function(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
node_visitors: Dict[str, NodeVisitor],
tosa_spec: TosaSpecification,
):
# Unpack arguments and convert
inputs = getNodeArgs(node)

# Convert output (this node itself)
output = TosaArg(node)

is_quant_node = is_node_quantized(node)
if is_quant_node:
output_dtype = map_dtype(get_quantized_node_output_dtype(node))
else:
output_dtype = output.dtype
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name,
tosa_shape(output.shape, output.dim_order),
output_dtype,
)

# Visiting each Node
# pyre-ignore[16]: Undefined attribute.
if node.target.__name__ in node_visitors:
# pyre-ignore[16]: Undefined attribute.
node_visitors[node.target.__name__].define_node(
node,
tosa_graph,
inputs,
output,
is_quant_node,
)
else:
raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")


def process_inputs(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
Expand Down Expand Up @@ -176,3 +216,13 @@ def process_placeholder(
)
else:
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
):
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
)
46 changes: 2 additions & 44 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,18 @@

import logging
import os
from typing import Any, cast, Dict
from typing import Any, cast

import numpy as np
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_downstream,
get_quant_arg_upstream,
get_quantized_node_output_dtype,
is_node_quantized,
q_op,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp
from torch.fx import Node
Expand Down Expand Up @@ -233,44 +229,6 @@ def tosa_shape(shape, dim_order):
return tuple([shape[dim] for dim in dim_order])


def process_call_function(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
node_visitors: Dict[str, NodeVisitor],
tosa_spec: TosaSpecification,
):
# Unpack arguments and convert
inputs = getNodeArgs(node)

# Convert output (this node itself)
output = TosaArg(node)

is_quant_node = is_node_quantized(node)
if is_quant_node:
output_dtype = map_dtype(get_quantized_node_output_dtype(node))
else:
output_dtype = output.dtype
tosa_graph.currRegion.currBasicBlock.addTensor(
output.name,
(tosa_shape(output.shape, output.dim_order)),
output_dtype,
)

# Visiting each Node
# pyre-ignore[16]: Undefined attribute.
if node.target.__name__ in node_visitors:
# pyre-ignore[16]: Undefined attribute.
node_visitors[node.target.__name__].define_node(
node,
tosa_graph,
inputs,
output,
is_quant_node,
)
else:
raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")


def expand_dims(
tosa_graph: ts.TosaSerializer,
input_node: TosaArg,
Expand Down
Loading