31 changes: 28 additions & 3 deletions backends/arm/_passes/insert_rescales_pass.py
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from copy import copy
from typing import cast, Dict, Optional, Set, Tuple, Type

@@ -93,6 +94,7 @@ class InsertRescaleInt32Pass(ArmPass):
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
]

def _int32_qargs(self, s):
@@ -133,6 +135,15 @@ def _get_inputs_rescaled_qparams(
qparams = {
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
}
elif target in [
exir_ops.edge.aten.mul.Tensor,
]:
# Unlike the comparison ops above, these ops do not need a common
# input scale; each input keeps its own scale unchanged.
qparams = {
i: self._int32_qargs(input_qparams[i].get_scale_per_tensor())
for i in range(len(input_qparams))
}
else:
raise ValueError(f"Not a valid target: {target}")
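
As a minimal standalone sketch of the two branches above (the bare-float helper and the op-name strings are simplifications, not the pass's real `QuantArgs`-based API):

```python
def inputs_rescaled_scales(target: str, input_scales: list[float]) -> list[float]:
    if target in ("gt", "lt", "ge", "le", "maximum", "minimum"):
        # Comparison-style ops need one common scale; picking the smallest
        # loses no precision, since every input is only scaled up to int32.
        return [min(input_scales)] * len(input_scales)
    elif target == "mul":
        # Mul needs no common scale; each input keeps its own.
        return list(input_scales)
    raise ValueError(f"Not a valid target: {target}")

assert inputs_rescaled_scales("maximum", [0.1, 0.025]) == [0.025, 0.025]
assert inputs_rescaled_scales("mul", [0.1, 0.025]) == [0.1, 0.025]
```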

@@ -161,6 +172,20 @@ def _get_output_qparams(
]:
# Output is bool for these ops and thus no qparams are present
return None
elif target in [exir_ops.edge.aten.mul.Tensor]:
# Mul multiplies the scales as well. Requiring the dequantized output
# to equal the product of the dequantized inputs,
#
# (Q_2 - ZP_2) * S_2 == ((Q_0 - ZP_0) * S_0) * ((Q_1 - ZP_1) * S_1)
#
# and choosing the output scale
#
# S_2 = S_0 * S_1
#
# reduces the op to a plain integer multiply:
#
# (Q_2 - ZP_2) == (Q_0 - ZP_0) * (Q_1 - ZP_1)
output_scale = math.prod(
(qp.get_scale_per_tensor() for qp in inputs_qparams.values())
)
return self._int32_qargs(output_scale)
else:
raise ValueError(f"Not a valid target: {target}")
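
A quick numeric check of the derivation above, with zero points taken as 0 and made-up scales:

```python
s0, s1 = 0.02, 0.5                     # input scales
x, y = 0.12, 1.5                       # real values being represented
q0, q1 = round(x / s0), round(y / s1)  # quantized inputs: 6 and 3
s2 = s0 * s1                           # output scale chosen by the pass: 0.01
q2 = q0 * q1                           # plain int32 multiply, no shift: 18
assert abs(q2 * s2 - x * y) < 1e-12    # 18 * 0.01 == 0.12 * 1.5 == 0.18
```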

@@ -187,7 +212,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
modified = False
for i in qargs:
qp = qargs[i]
if qp.dtype != torch.int8:
if qp.dtype not in (torch.int8, torch.int16):
continue

arg_node = args_copy[i]
@@ -226,7 +251,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
assert rescale_qargs is not None

qarg = qargs[0]
if qarg.dtype != torch.int8:
if qarg.dtype not in (torch.int8, torch.int16):
return False

users_copy = list(node.users)
@@ -237,7 +262,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
exir_ops.backend.tosa.RESCALE.default,
(
node,
torch.int8,
qarg.dtype,
rescale_qargs.get_scale_per_tensor()
/ qarg.get_scale_per_tensor(), # Old scale / new scale
rescale_qargs.get_zp_per_tensor(), # Old zero point
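
The dtype changes in the three hunks above widen the pass from int8-only to int8 and int16. A hypothetical standalone sketch of the output-side rescale it emits — `rescale_back` and its signature are illustrative, not the backend's actual RESCALE lowering:

```python
import torch

def rescale_back(acc_i32: torch.Tensor, old_scale: float, new_scale: float,
                 zp: int, dtype: torch.dtype) -> torch.Tensor:
    # int32 accumulator -> original quantized dtype; the multiplier is
    # old_scale / new_scale, matching the "Old scale / new scale" comment.
    info = torch.iinfo(dtype)  # int8 and int16 now both round-trip
    q = torch.round(acc_i32 * (old_scale / new_scale)) + zp
    return q.clamp(info.min, info.max).to(dtype)

acc = torch.tensor([1200], dtype=torch.int32)
print(rescale_back(acc, old_scale=0.0001, new_scale=0.02, zp=0, dtype=torch.int16))
# tensor([6], dtype=torch.int16)
```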
110 changes: 6 additions & 104 deletions backends/arm/operators/op_mul.py
@@ -7,14 +7,8 @@

from typing import Any, List

import executorch.backends.arm.tosa.quant_utils as tqutils
import executorch.backends.arm.tosa.utils as tutils
import torch

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
)

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -24,17 +18,17 @@
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification


@register_node_visitor
class MulVisitor_INT(NodeVisitor):
class MulVisitor(NodeVisitor):
target = "aten.mul.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
]

def define_node(
@@ -52,105 +46,13 @@ def define_node(
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
input_A = inputs[0]
input_B = inputs[1]
input_qparams = get_input_qparams(node)
input_A_qargs = input_qparams[0]
input_B_qargs = input_qparams[1]
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)

# Rescale inputs to INT32 with zp=0
input_A_rescaled = tqutils.build_rescale_to_int32(
tosa_graph,
input_A,
input_A_qargs.get_zp_per_tensor(),
1.0,
tosa_spec=self.tosa_spec,
)
input_B_rescaled = tqutils.build_rescale_to_int32(
tosa_graph,
input_B,
input_B_qargs.get_zp_per_tensor(),
1.0,
tosa_spec=self.tosa_spec,
)
else:
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
# Non quantized input, natively support by TOSA.MUL
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]

if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32 (non-quantized)
mul_output = output

# Do the INT32 Mul
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
self._serialize_operator(
node,
tosa_graph,
ts.TosaOp.Op().MUL,
[input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"],
[mul_output.name],
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
output_scale = (
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
)
tqutils.insert_rescale_op_to_int8(
tosa_graph, mul_output, output_scale, node, self.tosa_spec
)
elif output.dtype == ts.DType.INT16:
# Scale output back to 16 bit
output_scale = (
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
)
tqutils.insert_rescale_op_to_int16(
tosa_graph, mul_output, output_scale, node, self.tosa_spec
)


@register_node_visitor
class MulVisitor_FP(MulVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)

if inputs[0].dtype == ts.DType.INT8:
return super().define_node(node, tosa_graph, inputs, output)

input1, input2 = inputs

tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
self._serialize_operator(
node,
tosa_graph,
tosa_graph.addOperator(
ts.TosaOp.Op().MUL,
[input1.name, input2.name, f"{node.name}_shift"],
[inputs[0].name, inputs[1].name, f"{node.name}_shift"],
[output.name],
)
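
With the quantized path moved into `InsertRescaleInt32Pass`, the visitor collapses to this single TOSA MUL. A rough sketch of the per-dtype lowering this implies — the op strings are illustrative only:

```python
def lowered_mul_ops(dtype: str) -> list[str]:
    # Hypothetical summary of the graph per dtype after this change:
    # InsertRescaleInt32Pass wraps quantized mul in RESCALEs, so
    # MulVisitor itself only ever sees INT32 or FP32 operands.
    if dtype in ("int8", "int16"):
        return [f"RESCALE({dtype}->int32)",  # one per input
                f"RESCALE({dtype}->int32)",
                "MUL(int32, shift=0)",
                f"RESCALE(int32->{dtype})"]  # scale S_0 * S_1 folded back in
    return [f"MUL({dtype})"]                 # int32 / fp32: direct lowering

print(lowered_mul_ops("int16"))
print(lowered_mul_ops("fp32"))
```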
13 changes: 7 additions & 6 deletions backends/arm/test/passes/test_insert_rescale_i32_pass.py
@@ -22,10 +22,11 @@ def __init__(self):
super().__init__()

def forward(self, x, y):
a = torch.maximum(x, y)
b = torch.abs(a)
c = a > b
return c
a = x * y
b = torch.maximum(a, y)
c = torch.abs(b)
d = c > b
return d

def get_inputs(self, dtype) -> input_t:
if dtype == torch.float32:
@@ -45,8 +46,8 @@ def test_insert_rescales():
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
ops_after = {
# "number of op nodes with i8 output" + "number of i8 node inputs"
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2
+ 5,
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
+ 7,
}
pipeline = PassPipeline[input_t](
module,
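
As a back-of-envelope check of the updated expectation: in the new `forward()`, three ops produce a quantized output (`mul`, `maximum`, `abs`; `gt` returns bool) and the four ops consume seven quantized inputs in total:

```python
# (quantized input count, has quantized output) per op in forward():
ops = {
    "mul":     (2, True),
    "maximum": (2, True),
    "abs":     (1, True),
    "gt":      (2, False),  # bool output -> no output RESCALE
}
input_rescales = sum(n for n, _ in ops.values())            # 7
output_rescales = sum(1 for _, out in ops.values() if out)  # 3
assert input_rescales + output_rescales == 3 + 7
```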