Skip to content

Commit d3e19eb

Browse files
authored
Backend-agnostic implementation of quantized_conv_nchw
Differential Revision: D81465757. Pull Request resolved: #13955.
1 parent 41c299f commit d3e19eb

File tree

3 files changed

+416
-0
lines changed

3 files changed

+416
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ python_library(
129129
],
130130
typing = True,
131131
deps = [
132+
"fbcode//executorch/backends/cadence/aot:utils",
132133
"fbcode//caffe2:torch",
133134
"fbcode//executorch/exir:scalar_type",
134135
],

backends/cadence/aot/ref_implementations.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
# pyre-strict
88

9+
910
from typing import Optional
1011

1112
import torch
13+
1214
from executorch.exir.scalar_type import ScalarType
1315
from torch.library import impl, Library
1416

@@ -21,6 +23,8 @@
2123
ScalarType.QINT32: torch.qint32,
2224
}
2325

26+
_Number = bool | int | float
27+
2428

2529
@impl(m, "quantize_per_tensor")
2630
def quantize_per_tensor(
@@ -294,6 +298,82 @@ def quantized_layer_norm_per_tensor(
294298
)
295299

296300

301+
@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution operation.

    Dequantizes the operands, runs the convolution in float, then requantizes
    the result back to the input's integer dtype.

    Args:
    - input_tensor (Tensor): The activations tensor
    - weight (Tensor): The weight tensor
    - bias (Tensor): The bias tensor
    - stride (Tuple[int]): The stride of the convolution
    - padding (Tuple[int]): The padding of the convolution
    - dilation (Tuple[int]): The dilation of the convolution
    - groups (int): The number of groups
    - in_zero_point (int): The quantized mapping of zero for the input
    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
    - bias_scale (Tensor): The quantized bias scale
    - output_scale (float): The scale of the output
    - output_zero_point (int): The zero point of the output
    - out_multiplier (Tensor): Unused
    - out_shift (Tensor): Unused
    """
    # Only per-tensor (scalar) weight zero point / bias scale are supported.
    if weight_zero_point.view(-1).shape != (1,):
        raise ValueError("Weight zero point must be a scalar")
    if bias_scale.view(-1).shape != (1,):
        raise ValueError("Bias scale must be a scalar")

    # Shift operands to their real-valued (zero-centered) form before the
    # float convolution.
    shifted_input = (input_tensor - in_zero_point).float()
    shifted_weight = (weight - weight_zero_point).float()
    scaled_bias = (bias * bias_scale).float()

    rank = input_tensor.dim()
    if rank == 3:
        # 1D conv: the second element of each (pair) parameter applies.
        float_out = torch.nn.functional.conv1d(
            shifted_input,
            shifted_weight,
            scaled_bias,
            stride[1],
            padding[1],
            dilation[1],
            groups,
        )
    elif rank == 4:
        float_out = torch.nn.functional.conv2d(
            shifted_input,
            shifted_weight,
            scaled_bias,
            stride,
            padding,
            dilation,
            groups,
        )
    else:
        raise ValueError("Input tensor must be 3D or 4D")

    # Requantize to the input's dtype, saturating at its representable range.
    out_dtype = input_tensor.dtype
    dtype_info = torch.iinfo(out_dtype)
    return quantize_per_tensor(
        float_out,
        1.0 / output_scale,
        output_zero_point,
        dtype_info.min,
        dtype_info.max,
        out_dtype,
    )
375+
376+
297377
@impl(m, "requantize")
298378
def requantize(
299379
input: torch.Tensor,

0 commit comments

Comments
 (0)