|
| 1 | +# Copyright 2024 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# pyre-unsafe |
| 7 | + |
| 8 | +from typing import cast, List |
| 9 | + |
| 10 | +import executorch.backends.arm.tosa_quant_utils as tqutils |
| 11 | + |
| 12 | +import serializer.tosa_serializer as ts |
| 13 | +from executorch.backends.arm.operators.node_visitor import ( |
| 14 | + NodeVisitor, |
| 15 | + register_node_visitor, |
| 16 | +) |
| 17 | +from executorch.backends.arm.tosa_mapping import TosaArg |
| 18 | +from executorch.backends.arm.tosa_utils import tosa_shape |
| 19 | + |
| 20 | +from serializer.tosa_serializer import TosaOp |
| 21 | +from torch.fx import Node |
| 22 | + |
| 23 | + |
| 24 | +@register_node_visitor |
| 25 | +class MaxVisitor(NodeVisitor): |
| 26 | + target = "aten.maximum.default" |
| 27 | + |
| 28 | + def __init__(self, *args): |
| 29 | + super().__init__(*args) |
| 30 | + |
| 31 | + def define_node( |
| 32 | + self, |
| 33 | + node: Node, |
| 34 | + tosa_graph: ts.TosaSerializer, |
| 35 | + inputs: List[TosaArg], |
| 36 | + output: TosaArg, |
| 37 | + is_quant_node: bool, |
| 38 | + ) -> None: |
| 39 | + assert inputs[0].dtype == inputs[1].dtype |
| 40 | + |
| 41 | + input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"]) |
| 42 | + min_output = output |
| 43 | + |
| 44 | + if inputs[0].dtype == ts.DType.INT8: |
| 45 | + # insert RESCALEs to int32 |
| 46 | + x_scale = input_qparams[0].scale |
| 47 | + x_zp = input_qparams[0].zp |
| 48 | + |
| 49 | + y_scale = input_qparams[1].scale |
| 50 | + y_zp = input_qparams[1].zp |
| 51 | + |
| 52 | + assert ( |
| 53 | + x_zp == y_zp |
| 54 | + ), "Different zp for inputs, MAX should be quantized with shared quantization!" |
| 55 | + assert ( |
| 56 | + x_scale == y_scale |
| 57 | + ), "Different scale for input, MAX should be quantized with shared quantization!" |
| 58 | + |
| 59 | + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( |
| 60 | + tosa_graph, inputs, node |
| 61 | + ) |
| 62 | + |
| 63 | + output.shape = tosa_shape(output.shape, output.dim_order) |
| 64 | + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) |
| 65 | + else: |
| 66 | + operand_inputs = inputs |
| 67 | + |
| 68 | + tosa_graph.addOperator( |
| 69 | + TosaOp.Op().MAXIMUM, |
| 70 | + [ |
| 71 | + operand_inputs[0].name, |
| 72 | + operand_inputs[1].name, |
| 73 | + ], |
| 74 | + [min_output.name], |
| 75 | + ) |
| 76 | + |
| 77 | + if output.dtype == ts.DType.INT8: |
| 78 | + # insert RESCALE from int32 back to int8 |
| 79 | + tqutils.insert_rescale_node_back_to_int8( |
| 80 | + tosa_graph, min_output, scale_back, node |
| 81 | + ) |
0 commit comments