backends/cadence/aot/TARGETS (0 additions & 1 deletion)
@@ -129,7 +129,6 @@ python_library(
     ],
     typing = True,
     deps = [
-        "fbcode//executorch/backends/cadence/aot:utils",
         "fbcode//caffe2:torch",
         "fbcode//executorch/exir:scalar_type",
     ],

backends/cadence/aot/ref_implementations.py (115 additions & 4 deletions)
@@ -23,8 +23,6 @@
     ScalarType.QINT32: torch.qint32,
 }

-_Number = bool | int | float
-

 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
@@ -298,8 +296,7 @@ def quantized_layer_norm_per_tensor(
     )


-@impl(m, "quantized_conv_nchw")
-def quantized_conv_nchw(
+def quantized_conv(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -374,6 +371,120 @@ def quantized_conv_nchw(
     )


+@impl(m, "quantized_conv_nchw")
+def quantized_conv_nchw(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Quantized convolution operation for NCHW-layout inputs.
+
+    Args:
+        - input_tensor (Tensor): The activations tensor, in NCHW format
+        - weight (Tensor): The weight tensor
+        - bias (Tensor): The bias tensor
+        - stride (Tuple[int, int]): The stride of the convolution
+        - padding (Tuple[int, int]): The padding of the convolution
+        - dilation (Tuple[int, int]): The dilation of the convolution
+        - groups (int): The number of groups
+        - in_zero_point (int): The quantized mapping of zero for the input
+        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+        - bias_scale (Tensor): The quantized bias scale
+        - output_scale (float): The scale of the output
+        - output_zero_point (int): The zero point of the output
+        - out_multiplier (Tensor): Unused
+        - out_shift (Tensor): Unused
+    """
+    if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+        raise ValueError("Input tensor must be in NCHW format")
+    return quantized_conv(
+        input_tensor,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        in_zero_point,
+        weight_zero_point,
+        bias_scale,
+        output_scale,
+        output_zero_point,
+        out_multiplier,
+        out_shift,
+    )
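
The layout guard above leans on PyTorch's memory-format introspection rather than on shape inspection. As a minimal standalone sketch (plain PyTorch, nothing beyond the public tensor API), the check behaves as follows:

    import torch

    # A freshly allocated 4D tensor is contiguous, i.e. NCHW.
    x = torch.zeros(1, 3, 8, 8)
    assert x.is_contiguous(memory_format=torch.contiguous_format)
    assert not x.is_contiguous(memory_format=torch.channels_last)

    # Converting to channels_last keeps the logical (N, C, H, W) shape but
    # reorders strides so the channel dimension becomes innermost (NHWC).
    x_nhwc = x.to(memory_format=torch.channels_last)
    assert x_nhwc.is_contiguous(memory_format=torch.channels_last)

One caveat: tensors with a size-1 channel dimension satisfy both contiguity checks, since NCHW and NHWC coincide in memory there, so the guard is a best-effort layout test rather than an exact one.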


+@impl(m, "quantized_conv_nhwc")
+def quantized_conv_nhwc(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Quantized convolution operation for NHWC-layout inputs.
+
+    Args:
+        - input_tensor (Tensor): The activations tensor, in NHWC format
+        - weight (Tensor): The weight tensor
+        - bias (Tensor): The bias tensor
+        - stride (Tuple[int, int]): The stride of the convolution
+        - padding (Tuple[int, int]): The padding of the convolution
+        - dilation (Tuple[int, int]): The dilation of the convolution
+        - groups (int): The number of groups
+        - in_zero_point (int): The quantized mapping of zero for the input
+        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+        - bias_scale (Tensor): The quantized bias scale
+        - output_scale (float): The scale of the output
+        - output_zero_point (int): The zero point of the output
+        - out_multiplier (Tensor): Unused
+        - out_shift (Tensor): Unused
+    """
+
+    if not input_tensor.is_contiguous(memory_format=torch.channels_last):
+        raise ValueError("Input tensor must be in NHWC format")
+
+    return quantized_conv(
+        input_tensor,
+        weight,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+        in_zero_point,
+        weight_zero_point,
+        bias_scale,
+        output_scale,
+        output_zero_point,
+        out_multiplier,
+        out_shift,
+    )
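
For a caller-side view, here is a hypothetical invocation of the NHWC wrapper. The shapes, dtypes, and quantization parameters are illustrative assumptions, not values taken from this PR, and whether the delegated quantized_conv accepts them depends on its body, which is elided above:

    import torch

    # Illustrative inputs (shapes and dtypes are assumptions).
    x = torch.randint(-128, 128, (1, 3, 8, 8), dtype=torch.int8)  # NCHW by default
    w = torch.randint(-128, 128, (4, 3, 3, 3), dtype=torch.int8)
    b = torch.zeros(4, dtype=torch.int32)

    # The NHWC entry point requires channels_last strides; convert first.
    x_nhwc = x.contiguous(memory_format=torch.channels_last)

    out = quantized_conv_nhwc(
        x_nhwc, w, b,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1,
        in_zero_point=0,
        weight_zero_point=torch.tensor([0]),
        bias_scale=torch.tensor([1.0]),
        output_scale=0.1, output_zero_point=0,
        out_multiplier=torch.tensor([0]),  # unused per the docstring
        out_shift=torch.tensor([0]),       # unused per the docstring
    )

    # Passing the un-converted NCHW tensor instead raises:
    #   ValueError: Input tensor must be in NHWC format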


 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,