Skip to content

Commit 5948238

Browse files
authored
Backend-agnostic quantized_conv_nhwc (channels last)
Differential Revision: D81526626 Pull Request resolved: #13954
1 parent 339e9fc commit 5948238

File tree

3 files changed

+455
-271
lines changed

3 files changed

+455
-271
lines changed

backends/cadence/aot/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ python_library(
129129
],
130130
typing = True,
131131
deps = [
132-
"fbcode//executorch/backends/cadence/aot:utils",
133132
"fbcode//caffe2:torch",
134133
"fbcode//executorch/exir:scalar_type",
135134
],

backends/cadence/aot/ref_implementations.py

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
ScalarType.QINT32: torch.qint32,
2424
}
2525

26-
_Number = bool | int | float
27-
2826

2927
@impl(m, "quantize_per_tensor")
3028
def quantize_per_tensor(
@@ -298,8 +296,7 @@ def quantized_layer_norm_per_tensor(
298296
)
299297

300298

301-
@impl(m, "quantized_conv_nchw")
302-
def quantized_conv_nchw(
299+
def quantized_conv(
303300
input_tensor: torch.Tensor,
304301
weight: torch.Tensor,
305302
bias: torch.Tensor,
@@ -374,6 +371,120 @@ def quantized_conv_nchw(
374371
)
375372

376373

374+
@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NCHW (contiguous) activation tensor.

    Thin layout-checking wrapper: validates that the input is contiguous in
    standard (NCHW) memory order, then delegates the actual arithmetic to
    the backend-agnostic ``quantized_conv``.

    Args:
        - input_tensor (Tensor): The activations tensor
        - weight (Tensor): The weight tensor
        - bias (Tensor): The bias tensor
        - stride (Tuple[int]): The stride of the convolution
        - padding (Tuple[int]): The padding of the convolution
        - dilation (Tuple[int]): The dilation of the convolution
        - groups (int): The number of groups
        - in_zero_point (int): The quantized mapping of zero for the input
        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
        - bias_scale (Tensor): The quantized bias scale
        - output_scale (float): The scale of the output
        - output_zero_point (int): The zero point of the output
        - out_multiplier (Tensor): Unused
        - out_shift (Tensor): Unused

    Raises:
        ValueError: if ``input_tensor`` is not contiguous in NCHW order.
    """
    # Fast path: layout is valid, hand everything to the shared kernel.
    if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
        return quantized_conv(
            input_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            groups,
            in_zero_point,
            weight_zero_point,
            bias_scale,
            output_scale,
            output_zero_point,
            out_multiplier,
            out_shift,
        )
    raise ValueError("Input tensor must be in NCHW format")
428+
429+
430+
@impl(m, "quantized_conv_nhwc")
def quantized_conv_nhwc(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NHWC (channels-last) activation tensor.

    Thin layout-checking wrapper: validates that the input is contiguous in
    channels-last (NHWC) memory order, then delegates the actual arithmetic
    to the backend-agnostic ``quantized_conv``.

    Args:
        - input_tensor (Tensor): The activations tensor
        - weight (Tensor): The weight tensor
        - bias (Tensor): The bias tensor
        - stride (Tuple[int]): The stride of the convolution
        - padding (Tuple[int]): The padding of the convolution
        - dilation (Tuple[int]): The dilation of the convolution
        - groups (int): The number of groups
        - in_zero_point (int): The quantized mapping of zero for the input
        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
        - bias_scale (Tensor): The quantized bias scale
        - output_scale (float): The scale of the output
        - output_zero_point (int): The zero point of the output
        - out_multiplier (Tensor): Unused
        - out_shift (Tensor): Unused

    Raises:
        ValueError: if ``input_tensor`` is not contiguous in NHWC order.
    """
    # Fast path: layout is valid, hand everything to the shared kernel.
    if input_tensor.is_contiguous(memory_format=torch.channels_last):
        return quantized_conv(
            input_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            groups,
            in_zero_point,
            weight_zero_point,
            bias_scale,
            output_scale,
            output_zero_point,
            out_multiplier,
            out_shift,
        )
    raise ValueError("Input tensor must be in NHWC format")
486+
487+
377488
@impl(m, "requantize")
378489
def requantize(
379490
input: torch.Tensor,

0 commit comments

Comments (0)