# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")
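
# Once registered below, these operators are exposed under the custom namespace
# as torch.ops.cortex_m.<op_name>. Illustrative call (the tensor and the
# quantization parameters are example values only):
#
#   q = torch.ops.cortex_m.quantize_per_tensor(x, 0.1, 0, -128, 127, torch.int8)
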
###
# quantize_per_tensor
###
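# Each operator declares two schemas: the functional form used during tracing
# and lowering, and an out-variant ".out" form, which the ExecuTorch
# ahead-of-time flow is expected to convert calls to so the output buffer can
# be memory planned. Only the functional form gets a Python implementation in
# this file.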

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

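# register_fake supplies the "fake" (meta) kernel used while tracing/exporting:
# it only has to report the output's shape and dtype, so it returns an empty
# tensor of the target quantized dtype without computing any values.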
@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Reference implementation: delegates to the edge-dialect
    quantized_decomposed.quantize_per_tensor operator, so the semantics are
    identical.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
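
# Illustrative example (values are examples, not part of the library): with
# scale=0.1, zero_point=0, quant_min=-128, quant_max=127, dtype=torch.int8,
# an input value of 1.27 maps to clamp(round(1.27 / 0.1) + 0, -128, 127) = 13,
# following the usual affine quantization formula of the edge-dialect op.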


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Reference implementation: delegates to the edge-dialect
    quantized_decomposed.dequantize_per_tensor operator, so the semantics are
    identical.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
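
# Illustrative round trip (values are examples): dequantizing the int8 value 13
# produced above with scale=0.1 and zero_point=0 yields (13 - 0) * 0.1 = 1.3,
# recovering the original 1.27 to within one quantization step. The result is
# float32, matching the meta kernel above.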