|
| 1 | +# Copyright 2023-2024 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +# pyre-unsafe |
| 7 | +from typing import List |
| 8 | + |
| 9 | +import serializer.tosa_serializer as ts |
| 10 | +import torch |
| 11 | +from executorch.backends.arm.operators.node_visitor import ( |
| 12 | + NodeVisitor, |
| 13 | + register_node_visitor, |
| 14 | +) |
| 15 | +from executorch.backends.arm.tosa_mapping import TosaArg |
| 16 | +from executorch.backends.arm.tosa_specification import TosaSpecification |
| 17 | +from executorch.backends.arm.tosa_utils import tosa_shape |
| 18 | +from serializer.tosa_serializer import TosaOp |
| 19 | + |
| 20 | + |
| 21 | +@register_node_visitor |
| 22 | +class DivVisitor(NodeVisitor): |
| 23 | + target = "aten.div.Tensor" |
| 24 | + |
| 25 | + # Only supported for MI |
| 26 | + tosa_specs = [ |
| 27 | + TosaSpecification.create_from_string("TOSA-0.80+MI"), |
| 28 | + ] |
| 29 | + |
| 30 | + def __init__(self, *args): |
| 31 | + super().__init__(*args) |
| 32 | + |
| 33 | + def define_node( |
| 34 | + self, |
| 35 | + node: torch.fx.Node, |
| 36 | + tosa_graph: ts.TosaSerializer, |
| 37 | + inputs: List[TosaArg], |
| 38 | + output: TosaArg, |
| 39 | + is_quant_node: bool, |
| 40 | + ) -> None: |
| 41 | + # FP32 Div is implemented as output=x/y -> output=x*1/y e.g. MUL(x,RECIPROCAL(y)) |
| 42 | + recip = tosa_graph.addIntermediate( |
| 43 | + tosa_shape(inputs[1].shape, inputs[1].dim_order), inputs[1].dtype |
| 44 | + ) |
| 45 | + tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name]) |
| 46 | + |
| 47 | + attr = ts.TosaSerializerAttribute() |
| 48 | + attr.MulAttribute(0) |
| 49 | + tosa_graph.addOperator( |
| 50 | + TosaOp.Op().MUL, [inputs[0].name, recip.name], [output.name], attr |
| 51 | + ) |
0 commit comments