Skip to content
Merged
9 changes: 8 additions & 1 deletion backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops


Expand Down Expand Up @@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# Hardware specific constraints
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
# TODO remove this once TOSA 1.0 support for u55 is added.
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
return False
return True
else:
return self._is_node_supported_u55(node)
Expand Down
134 changes: 129 additions & 5 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List
from typing import Any, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand All @@ -33,10 +32,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand All @@ -53,7 +55,7 @@ def define_node(
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
)
) # type: ignore[possibly-undefined]
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.abs
Expand Down Expand Up @@ -96,10 +98,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down Expand Up @@ -129,3 +134,122 @@ def define_node(
[output.name],
None,
)


@register_node_visitor
class AbsVisitor_INT(NodeVisitor):
target = "aten.abs.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
# Handle int8 (quantized) and int32
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
raise ValueError(
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node, self.tosa_specs
) # type: ignore[possibly-undefined]
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.abs
rescaled_inputs = inputs

if output.dtype == ts.DType.INT8:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
abs_output = output

# Do the INT32 Abs
tosa_graph.addOperator(
ts.TosaOp.Op().ABS,
[
rescaled_inputs[0].name,
],
[abs_output.name],
None,
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph, abs_output, scale_back, node, self.tosa_specs
) # type: ignore[possibly-undefined]


@register_node_visitor
class AbsVisitor_FP(AbsVisitor_INT):
# inheriting 'target' from BI class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Abs lowering

if not (inputs[0].dtype == ts.DType.FP32):
raise ValueError(
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
)

if not (output.dtype == ts.DType.FP32):
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")

# MI lowering
tosa_graph.addOperator(
ts.TosaOp.Op().ABS,
[inputs[0].name],
[output.name],
None,
)
140 changes: 133 additions & 7 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

# pyre-unsafe

from typing import List
from typing import Any, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand All @@ -34,10 +33,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand All @@ -58,7 +60,7 @@ def define_node(
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
Expand Down Expand Up @@ -90,7 +92,9 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
tqutils.insert_rescale_op_to_int8(
tosa_graph, add_output, scale_back, node
) # type: ignore[possibly-undefined]


@register_node_visitor
Expand All @@ -107,10 +111,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand All @@ -130,7 +137,7 @@ def define_node(
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
)

input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
input1, input2 = inputs

# MI lowering
tosa_graph.addOperator(
Expand All @@ -139,3 +146,122 @@ def define_node(
[output.name],
None,
)


@register_node_visitor
class AddVisitor_INT(NodeVisitor):
target = "aten.add.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)
# Handle int8 (quantized) and int32
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
)
scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node, self.tosa_specs
)
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.ADD
rescaled_inputs = inputs

if output.dtype == ts.DType.INT8:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
add_output = output

input1, input2 = rescaled_inputs

# Do the INT32 Add
tosa_graph.addOperator(
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[add_output.name],
None,
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph, add_output, scale_back, node, self.tosa_specs
) # type: ignore[possibly-undefined]


@register_node_visitor
class AddVisitor_FP(AddVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Add lowering
if inputs[0].dtype != ts.DType.FP32:
raise TypeError(
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
)

input1, input2 = inputs

# FP lowering
tosa_graph.addOperator(
ts.TosaOp.Op().ADD,
[input1.name, input2.name],
[output.name],
None,
)
Loading
Loading