|  | 
|  | 1 | +# Copyright 2025 Arm Limited and/or its affiliates. | 
|  | 2 | +# | 
|  | 3 | +# This source code is licensed under the BSD-style license found in the | 
|  | 4 | +# LICENSE file in the root directory of this source tree. | 
|  | 5 | + | 
|  | 6 | +# pyre-unsafe | 
|  | 7 | +from typing import List | 
|  | 8 | + | 
|  | 9 | +import torch | 
|  | 10 | + | 
|  | 11 | +import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore | 
|  | 12 | + | 
|  | 13 | +from executorch.backends.arm.operators.node_visitor import ( | 
|  | 14 | +    NodeVisitor, | 
|  | 15 | +    register_node_visitor, | 
|  | 16 | +) | 
|  | 17 | +from executorch.backends.arm.tosa_mapping import TosaArg | 
|  | 18 | +from executorch.backends.arm.tosa_quant_utils import build_rescale | 
|  | 19 | +from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape | 
|  | 20 | +from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode  # type: ignore | 
|  | 21 | + | 
|  | 22 | + | 
|  | 23 | +@register_node_visitor | 
|  | 24 | +class UpsampleBilinear2dVisitor_0_80(NodeVisitor): | 
|  | 25 | +    target = "aten.upsample_bilinear2d.vec" | 
|  | 26 | + | 
|  | 27 | +    def __init__(self, *args): | 
|  | 28 | +        super().__init__(*args) | 
|  | 29 | + | 
|  | 30 | +    def define_node( | 
|  | 31 | +        self, | 
|  | 32 | +        node: torch.fx.Node, | 
|  | 33 | +        tosa_graph: ts.TosaSerializer, | 
|  | 34 | +        inputs: List[TosaArg], | 
|  | 35 | +        output: TosaArg, | 
|  | 36 | +    ) -> None: | 
|  | 37 | +        assert ( | 
|  | 38 | +            inputs[0].shape is not None and output.shape is not None | 
|  | 39 | +        ), "Only static shapes are supported" | 
|  | 40 | + | 
|  | 41 | +        input_dtype = inputs[0].dtype | 
|  | 42 | + | 
|  | 43 | +        # tosa_shape output is NHWC, take HW | 
|  | 44 | +        input_size_yx = torch.tensor( | 
|  | 45 | +            tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] | 
|  | 46 | +        ) | 
|  | 47 | +        # Ignore scale and size parameters, directly use the output size as | 
|  | 48 | +        # we only support static shapes currently | 
|  | 49 | +        output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) | 
|  | 50 | + | 
|  | 51 | +        scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( | 
|  | 52 | +            input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True | 
|  | 53 | +        ) | 
|  | 54 | + | 
|  | 55 | +        def in_int16_range(x): | 
|  | 56 | +            return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) | 
|  | 57 | + | 
|  | 58 | +        assert in_int16_range(scale_n_yx) | 
|  | 59 | +        assert in_int16_range(scale_d_yx) | 
|  | 60 | +        assert in_int16_range(border_yx) | 
|  | 61 | + | 
|  | 62 | +        attr = ts.TosaSerializerAttribute() | 
|  | 63 | +        attr.ResizeAttribute( | 
|  | 64 | +            scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]], | 
|  | 65 | +            offset=offset_yx.tolist(), | 
|  | 66 | +            border=border_yx.tolist(), | 
|  | 67 | +            mode=ResizeMode.BILINEAR, | 
|  | 68 | +        ) | 
|  | 69 | + | 
|  | 70 | +        if input_dtype == output.dtype == ts.DType.FP32: | 
|  | 71 | +            tosa_graph.addOperator( | 
|  | 72 | +                ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr | 
|  | 73 | +            ) | 
|  | 74 | +            return | 
|  | 75 | +        elif input_dtype == output.dtype == ts.DType.INT8: | 
|  | 76 | +            intermediate = tosa_graph.addIntermediate( | 
|  | 77 | +                tosa_shape(output.shape, output.dim_order), ts.DType.INT32 | 
|  | 78 | +            ) | 
|  | 79 | + | 
|  | 80 | +            tosa_graph.addOperator( | 
|  | 81 | +                ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr | 
|  | 82 | +            ) | 
|  | 83 | + | 
|  | 84 | +            final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) | 
|  | 85 | + | 
|  | 86 | +            build_rescale( | 
|  | 87 | +                tosa_fb=tosa_graph, | 
|  | 88 | +                scale=[final_output_scale], | 
|  | 89 | +                input_node=intermediate, | 
|  | 90 | +                output_name=output.name, | 
|  | 91 | +                output_type=ts.DType.INT8, | 
|  | 92 | +                output_shape=output.shape, | 
|  | 93 | +                input_zp=0, | 
|  | 94 | +                output_zp=0, | 
|  | 95 | +                is_double_round=False, | 
|  | 96 | +            ) | 
|  | 97 | +        else: | 
|  | 98 | +            raise ValueError( | 
|  | 99 | +                "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" | 
|  | 100 | +            ) | 
0 commit comments