1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -129,6 +129,7 @@ python_library(
],
typing = True,
deps = [
"fbcode//executorch/backends/cadence/aot:utils",
"fbcode//caffe2:torch",
"fbcode//executorch/exir:scalar_type",
],
161 changes: 154 additions & 7 deletions backends/cadence/aot/ref_implementations.py
@@ -6,9 +6,11 @@

# pyre-strict


from typing import Optional

import torch

from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

@@ -21,10 +23,12 @@
ScalarType.QINT32: torch.qint32,
}

_Number = bool | int | float


@impl(m, "quantize_per_tensor")
def quantize_per_tensor(
input: torch.Tensor,
input_tensor: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
@@ -35,10 +39,10 @@ def quantize_per_tensor(
Quantizes a floating-point tensor to an integral tensor.

Args:
- input (Tensor): input tensor
- scale (float): Quantization scale. Derived from the ratio
- input_tensor (Tensor): input tensor
- scale (float): Inverse of quantization scale. Derived from the ratio
between the min/max of the floating-point tensor and the
min/max of the quantized range.
min/max of the quantized range, and then inverted.
- zero_point (int): The point which represents 0 in the quantized
range. For example, consider the floating point range [-1., 2.] and
quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +65,12 @@ def quantize_per_tensor(
raise ValueError(
f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
)
return torch.round(input / scale + zero_point).to(dtype)

quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
return torch.clamp(quantized, quant_min, quant_max)
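
A minimal standalone sketch of this quantization step, using the docstring's example ranges (float [-1., 2.] into int [-7, 7]). The inverse scale 14/3 and zero point -2 below are illustrative values derived from that example, not constants from the kernel:

import torch

# Real scale would be (2.0 - (-1.0)) / (7 - (-7)) = 3/14; the op takes its inverse.
inv_scale = 14.0 / 3.0
# 0.0 sits one third of the way through both ranges, so it maps near -7 + 14/3 ≈ -2.33.
zero_point = -2
x = torch.tensor([-1.0, 0.0, 2.0])
q = torch.round(x * inv_scale + zero_point).clamp(-7, 7).to(torch.int8)
# q == tensor([-7, -2, 7], dtype=torch.int8)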


@impl(m, "dequantize_per_tensor")
@@ -173,9 +182,16 @@ def quantized_add(
dequant_X = X_scale * (X - X_zero_point)
dequant_Y = Y_scale * (Y - Y_zero_point)

out_scale_inv = 1 / out_scale

# quantized_add takes no quant_min/quant_max args, so pass the full range of dtype
return quantize_per_tensor(
dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
dequant_X + dequant_Y,
out_scale_inv,
out_zero_point,
torch.iinfo(dtype).min,
torch.iinfo(dtype).max,
dtype,
)
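
A hedged sketch of the dequantize-add-requantize flow above; all scales and zero points are made-up illustrative values:

import torch

x = torch.tensor([10, 20], dtype=torch.int8)
y = torch.tensor([30, 40], dtype=torch.int8)
x_scale, x_zp = 0.1, 0
y_scale, y_zp = 0.05, 5
out_scale, out_zp = 0.2, -1

fx = x_scale * (x.float() - x_zp)  # dequantize X
fy = y_scale * (y.float() - y_zp)  # dequantize Y
q = torch.round((fx + fy) / out_scale + out_zp)  # requantize the float sum
q = q.clamp(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max).to(torch.int8)
# q == tensor([10, 18], dtype=torch.int8)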


@@ -206,6 +222,7 @@ def quantized_linear(
- offset (Tensor): Unused
"""
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
out_scale_inv = 1 / out_scale

N, K = weight.shape

@@ -223,10 +240,140 @@ def quantized_linear(
src - in_zero_point, weight - weight_zero_point, bias
)
return quantize_per_tensor(
out, out_scale, out_zero_point, -128, 127, dtype
out,
out_scale_inv,
out_zero_point,
torch.iinfo(dtype).min,
torch.iinfo(dtype).max,
dtype,
).reshape(*leading_dims, N)
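
A hedged sketch of the integer matmul with zero-point correction that quantized_linear performs before requantizing; shapes and values are illustrative, and the flat out_scale_inv below stands in for the out_multiplier/out_shift fixed-point computation above:

import torch

src = torch.tensor([[1, 2]], dtype=torch.int8)             # (1, K)
weight = torch.tensor([[3, 4], [5, 6]], dtype=torch.int8)  # (N, K)
bias = torch.tensor([0, 0], dtype=torch.int32)
in_zp, w_zp = 1, 0
out_scale_inv, out_zp = 0.5, 0

# accumulate in float after shifting out the zero points
acc = torch.nn.functional.linear(
    (src - in_zp).float(), (weight - w_zp).float(), bias.float()
)
q = torch.round(acc * out_scale_inv + out_zp).clamp(-128, 127).to(torch.int8)
# src - in_zp == [[0, 1]], so acc == tensor([[4., 6.]]) and q == tensor([[2, 3]])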


@impl(m, "quantized_layer_norm_per_tensor")
def quantized_layer_norm_per_tensor(
input_tensor: torch.Tensor,
X_scale: float,
X_zero_point: int,
normalized_shape: int,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
output_scale: float,
output_zero_point: int,
) -> torch.Tensor:
"""
Quantized layer norm operation.

Args:
- input_tensor (Tensor): The activations tensor
- X_scale (float): The scale of the input
- X_zero_point (int): The zero point of the input
- normalized_shape (int): The size of the last dimension, over which normalization is applied
- weight (Tensor): The weight tensor
- bias (Tensor): The bias tensor
- eps (float): The epsilon value
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
"""
supported_dtypes = [torch.int8, torch.uint8]
if input_tensor.dtype not in supported_dtypes:
raise ValueError(
f"Input dtype must be one of {supported_dtypes}. Got {input_tensor.dtype}"
)

float_input_tensor = dequantize_per_tensor(
input_tensor, X_scale, X_zero_point, torch.iinfo(input_tensor.dtype).min, torch.iinfo(input_tensor.dtype).max, torch.float32
)
out = torch.nn.functional.layer_norm(
float_input_tensor, (normalized_shape,), weight, bias, eps=eps
)

return quantize_per_tensor(
out,
1 / output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)
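
A hedged sketch of the pattern above: dequantize to float, run the standard layer norm, then requantize. Scales, zero points, and shapes are illustrative:

import torch

q_in = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
x_scale, x_zp = 0.05, 3
out_scale, out_zp = 0.1, 0

f = x_scale * (q_in.float() - x_zp)  # dequantize
f = torch.nn.functional.layer_norm(f, (8,), torch.ones(8), torch.zeros(8), eps=1e-5)
q = torch.round(f / out_scale + out_zp).clamp(-128, 127).to(torch.int8)  # requantize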


@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: tuple[int, int],
padding: tuple[int, int],
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
) -> torch.Tensor:
"""
Quantized convolution operation.

Args:
- input_tensor (Tensor): The activations tensor
- weight (Tensor): The weight tensor
- bias (Tensor): The bias tensor
- stride (Tuple[int]): The stride of the convolution
- padding (Tuple[int]): The padding of the convolution
- dilation (Tuple[int]): The dilation of the convolution
- groups (int): The number of groups
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- bias_scale (Tensor): The quantized bias scale
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
- out_multiplier (Tensor): Unused
- out_shift (Tensor): Unused
"""
if weight_zero_point.view(-1).shape != (1,):
raise ValueError("Weight zero point must be a scalar")

if bias_scale.view(-1).shape != (1,):
raise ValueError("Bias scale must be a scalar")

if len(input_tensor.shape) == 3:
float_out = torch.nn.functional.conv1d(
(input_tensor - in_zero_point).float(),
(weight - weight_zero_point).float(),
(bias * bias_scale).float(),
stride[1],
padding[1],
dilation[1],
groups,
)

elif len(input_tensor.shape) == 4:
float_out = torch.nn.functional.conv2d(
(input_tensor - in_zero_point).float(),
(weight - weight_zero_point).float(),
(bias * bias_scale).float(),
stride,
padding,
dilation,
groups,
)
else:
raise ValueError("Input tensor must be 3D or 4D")

return quantize_per_tensor(
float_out,
1.0 / output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
input_tensor.dtype,
)
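
A hedged sketch of the conv2d branch above: shift out the zero points, convolve in float, then requantize; all values are illustrative:

import torch

x = torch.randint(0, 10, (1, 1, 4, 4), dtype=torch.int8)
w = torch.randint(-5, 5, (1, 1, 3, 3), dtype=torch.int8)
b = torch.zeros(1, dtype=torch.int32)
in_zp, w_zp, bias_scale = 2, 0, 1.0
out_scale, out_zp = 0.5, 0

f = torch.nn.functional.conv2d(
    (x - in_zp).float(), (w - w_zp).float(), (b * bias_scale).float(),
    stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1,
)
q = torch.round(f / out_scale + out_zp).clamp(-128, 127).to(torch.int8)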


@impl(m, "requantize")
def requantize(
input: torch.Tensor,