68 changes: 31 additions & 37 deletions backends/cadence/aot/ref_implementations.py
@@ -296,7 +296,7 @@ def quantized_layer_norm_per_tensor(
)


def quantized_conv(
def quantized_conv_per_tensor(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
@@ -305,12 +305,12 @@ def quantized_conv(
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
"""
Quantized convolution operation.
@@ -324,19 +324,13 @@ def quantized_conv(
- dilation (Tuple[int]): The dilation of the convolution
- groups (int): The number of groups
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- bias_scale (Tensor): The quantized bias scale
- weight_zero_point (int): The quantized mapping of zero for the weight
- bias_scale (float): The quantized bias scale
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
- out_multiplier (Tensor): Unused
- out_shift (Tensor): Unused
- out_multiplier (int): Unused
- out_shift (int): Unused
"""
if weight_zero_point.view(-1).shape != (1,):
raise ValueError("Weight zero point must be a scalar")

if bias_scale.view(-1).shape != (1,):
raise ValueError("Bias scale must be a scalar")

if len(input_tensor.shape) == 3:
float_out = torch.nn.functional.conv1d(
(input_tensor - in_zero_point).float(),
@@ -371,8 +365,8 @@ def quantized_conv(
)


@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
@impl(m, "quantized_conv_nchw_per_tensor")
def quantized_conv_nchw_per_tensor(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
@@ -381,12 +375,12 @@ def quantized_conv_nchw(
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
"""
Quantized convolution operation.
@@ -400,16 +394,16 @@ def quantized_conv_nchw(
- dilation (Tuple[int]): The dilation of the convolution
- groups (int): The number of groups
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- bias_scale (Tensor): The quantized bias scale
- weight_zero_point (int): The quantized mapping of zero for the weight
- bias_scale (float): The quantized bias scale
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
- out_multiplier (Tensor): Unused
- out_shift (Tensor): Unused
- out_multiplier (int): Unused
- out_shift (int): Unused
"""
if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
raise ValueError("Input tensor must be in NCHW format")
return quantized_conv(
return quantized_conv_per_tensor(
input_tensor,
weight,
bias,
@@ -427,8 +421,8 @@ def quantized_conv_nchw(
)


@impl(m, "quantized_conv_nhwc")
def quantized_conv_nhwc(
@impl(m, "quantized_conv_nhwc_per_tensor")
def quantized_conv_nhwc_per_tensor(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
@@ -437,12 +431,12 @@ def quantized_conv_nhwc(
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
"""
Quantized convolution operation.
@@ -456,18 +450,18 @@ def quantized_conv_nhwc(
- dilation (Tuple[int]): The dilation of the convolution
- groups (int): The number of groups
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- bias_scale (Tensor): The quantized bias scale
- weight_zero_point (int): The quantized mapping of zero for the weight
- bias_scale (float): The quantized bias scale
- output_scale (float): The scale of the output
- output_zero_point (int): The zero point of the output
- out_multiplier (Tensor): Unused
- out_shift (Tensor): Unused
- out_multiplier (int): Unused
- out_shift (int): Unused
"""

if not input_tensor.is_contiguous(memory_format=torch.channels_last):
raise ValueError("Input tensor must be in NHWC format")

return quantized_conv(
return quantized_conv_per_tensor(
input_tensor,
weight,
bias,
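To make the effect of this change concrete, here is a minimal sketch of the dequantize → float convolution → requantize path that the renamed `quantized_conv_per_tensor` follows, written against the new scalar `weight_zero_point` and `bias_scale` parameters. Only the input shift and the conv call are visible in the hunks above; the bias handling and the final rounding/clamping step below are assumptions, not the exact implementation.

```python
# Hedged sketch of the per-tensor quantized conv reference path (NCHW, 2D case).
# The collapsed part of the diff may round/clamp differently.
import torch


def quantized_conv2d_per_tensor_sketch(
    input_tensor: torch.Tensor,  # quantized input, e.g. int8, NCHW
    weight: torch.Tensor,        # quantized weight
    bias: torch.Tensor,          # quantized bias
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,      # scalar per-tensor zero point (was a Tensor before this PR)
    bias_scale: float,           # scalar per-tensor bias scale (was a Tensor before this PR)
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    # Shift input and weight to their real-valued ranges and convolve in float.
    float_out = torch.nn.functional.conv2d(
        (input_tensor - in_zero_point).float(),
        (weight - weight_zero_point).float(),
        (bias * bias_scale).float(),
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    # Requantize to the output scale/zero point (rounding and int8 clamp assumed here).
    out = torch.round(float_out / output_scale) + output_zero_point
    return out.clamp(-128, 127).to(torch.int8)
```

The point of the per-tensor variants is that a single int/float describes the whole tensor's quantization, which is why the scalar-shape checks on `weight_zero_point` and `bias_scale` could be deleted above.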
62 changes: 27 additions & 35 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -15,8 +15,8 @@
dequantize_per_tensor,
quantize_per_tensor,
quantized_add,
quantized_conv_nchw,
quantized_conv_nhwc,
quantized_conv_nchw_per_tensor,
quantized_conv_nhwc_per_tensor,
quantized_layer_norm_per_tensor,
quantized_linear,
quantized_relu,
@@ -356,8 +356,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int8), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.1, # output_scale
0, # output_zero_point
torch.tensor(
@@ -387,8 +387,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int8), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.25, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -416,8 +416,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
128, # in_zero_point
torch.tensor([128], dtype=torch.uint8), # weight_zero_point
torch.tensor([0.1], dtype=torch.float32), # bias_scale
128, # weight_zero_point
0.1, # bias_scale
0.1, # output_scale
128, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -447,8 +447,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation (padding for 2D, actual dilation is dilation[1])
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int8), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.5, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -482,8 +482,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int8), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.2, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -523,8 +523,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int16), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.1, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -576,12 +576,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor(
[0], dtype=torch.int16
), # weight_zero_point for each output channel
torch.tensor(
[1.0], dtype=torch.float32
), # bias_scale for each channel
0, # weight_zero_point
1.0, # bias_scale
0.05, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -623,12 +619,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
2, # groups (grouped convolution)
0, # in_zero_point
torch.tensor(
[0], dtype=torch.int8
), # weight_zero_point for each output channel
torch.tensor(
[1.0], dtype=torch.float32
), # bias_scale for each channel
0, # weight_zero_point
1.0, # bias_scale
0.2, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -666,8 +658,8 @@ def test_quantized_layer_norm_per_tensor(
(1, 1), # dilation
1, # groups
0, # in_zero_point
torch.tensor([0], dtype=torch.int8), # weight_zero_point
torch.tensor([1.0], dtype=torch.float32), # bias_scale
0, # weight_zero_point
1.0, # bias_scale
0.5, # output_scale
0, # output_zero_point
typing.cast(None, torch.Tensor),
@@ -682,7 +674,7 @@ def test_quantized_layer_norm_per_tensor(
],
]
)
def test_quantized_conv(
def test_quantized_conv_per_tensor(
self,
input_tensor: torch.Tensor,
weight: torch.Tensor,
@@ -692,12 +684,12 @@ def test_quantized_conv(
dilation: tuple[int, int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
weight_zero_point: int,
bias_scale: float,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: int,
out_shift: int,
dtype: torch.dtype,
expected_output: torch.Tensor,
memory_format: torch.memory_format,
@@ -710,9 +702,9 @@ def test_quantized_conv(
input_tensor = input_tensor.to(memory_format=memory_format)

conv = (
quantized_conv_nchw
quantized_conv_nchw_per_tensor
if memory_format == torch.contiguous_format
else quantized_conv_nhwc
else quantized_conv_nhwc_per_tensor
)

output = conv(
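For reference, a call to the renamed NCHW wrapper with the new scalar arguments might look like the sketch below, mirroring the updated test parameterization. The tensor values, bias dtype, and import path are illustrative assumptions, not taken from the tests.

```python
# Illustrative call with the new scalar quantization parameters (values are made up).
import torch

# Assumed import path; the test file's actual "from ... import (" line is outside the diff.
from executorch.backends.cadence.aot.ref_implementations import (
    quantized_conv_nchw_per_tensor,
)

input_tensor = torch.zeros(1, 1, 4, 4, dtype=torch.int8)  # NCHW, contiguous
weight = torch.ones(1, 1, 3, 3, dtype=torch.int8)
bias = torch.zeros(1, dtype=torch.int32)

out = quantized_conv_nchw_per_tensor(
    input_tensor,
    weight,
    bias,
    (1, 1),  # stride
    (0, 0),  # padding
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point (now a plain int)
    1.0,     # bias_scale (now a plain float)
    0.1,     # output_scale
    0,       # output_zero_point
    0,       # out_multiplier (unused)
    0,       # out_shift (unused)
)
```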