
Commit 8e208ad

Split quantized convolutions into NCHW and NHWC variants
Differential Revision: D79940643
Pull Request resolved: #13383
1 parent 4a1cbac commit 8e208ad

16 files changed, +1712 -1411 lines
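The split encodes the tensor layout in the operator name instead of a runtime channel_last flag: NCHW is channels-first (PyTorch's default), NHWC is channels-last. As a minimal background sketch in plain PyTorch, independent of the Cadence ops changed below:

import torch

# NCHW: (batch, channels, height, width) -- PyTorch's default conv layout.
x_nchw = torch.randn(1, 3, 8, 8)

# NHWC: (batch, height, width, channels) -- channels-last, often cheaper on
# DSPs because each pixel's channels are contiguous in memory.
x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()

assert x_nchw.shape == (1, 3, 8, 8)
assert x_nhwc.shape == (1, 8, 8, 3)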

backends/cadence/aot/functions.yaml

Lines changed: 14 additions & 4 deletions

@@ -190,10 +190,15 @@
     - arg_meta: null
       kernel_name: impl::reference::dequantize_per_tensor_out
 
-- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::reference::quantized_conv_out
+      kernel_name: impl::reference::quantized_conv_nchw_out
+
+- func: cadence::quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_out
 
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
@@ -269,10 +274,15 @@
     - arg_meta: null
       kernel_name: impl::reference::im2row_per_tensor_out
 
-- func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: impl::reference::quantized_conv_per_tensor_out
+      kernel_name: impl::reference::quantized_conv_nhwc_per_tensor_out
 
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/functions_hifi.yaml

Lines changed: 7 additions & 2 deletions

@@ -290,10 +290,15 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::dequantize_per_tensor_out
 
-- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: cadence::quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::HiFi::quantized_conv_out
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_out
+
+- func: cadence::quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_out
 
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:

backends/cadence/aot/ops_registrations.py

Lines changed: 114 additions & 23 deletions

@@ -85,18 +85,29 @@
 )
 
 lib.define(
-    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
+    "quantized_conv_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
+    "quantized_conv_nhwc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nchw.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
-
 lib.define(
     "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
 )
@@ -532,8 +543,8 @@ def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta(
     return src.new_empty(out_size, dtype=src.dtype)
 
 
-@register_fake("cadence::quantized_conv")
-def quantized_conv_meta(
+@register_fake("cadence::quantized_conv_nhwc")
+def quantized_conv_nhwc_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -548,12 +559,8 @@ def quantized_conv_meta(
     output_zero_point: int,
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
-    channel_last: bool = False,
 ) -> torch.Tensor:
-    if channel_last:
-        out_channels, *kernel_size, _ = weight.shape
-    else:
-        out_channels, _, *kernel_size = weight.shape
+    out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
@@ -569,19 +576,63 @@ def quantized_conv_meta(
             padding[1],
             dilation[1],
             kernel_size[0],
-            channel_last,
+            True,
         )
         if len(in_size) == 3
         else get_conv2d_output_size(
-            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
         )
     )
 
     return input.new_empty(output_size, dtype=input.dtype)
 
 
-@register_fake("cadence::quantized_conv.per_tensor")
-def quantized_conv_per_tensor_meta(
+@register_fake("cadence::quantized_conv_nchw")
+def quantized_conv_nchw_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+) -> torch.Tensor:
+    out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            False,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nchw.per_tensor")
+def quantized_conv_nchw_per_tensor_meta(
     input: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -596,12 +647,8 @@ def quantized_conv_per_tensor_meta(
     output_zero_point: int,
     out_multiplier: int,
     out_shift: int,
-    channel_last: bool = False,
 ) -> torch.Tensor:
-    if channel_last:
-        out_channels, *kernel_size, _ = weight.shape
-    else:
-        out_channels, _, *kernel_size = weight.shape
+    out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
@@ -617,11 +664,55 @@ def quantized_conv_per_tensor_meta(
             padding[1],
             dilation[1],
             kernel_size[0],
-            channel_last,
+            False,
        )
         if len(in_size) == 3
         else get_conv2d_output_size(
-            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc.per_tensor")
+def quantized_conv_nhwc_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, *kernel_size, _ = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            True,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
         )
     )
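The NCHW and NHWC meta kernels above differ only in which axis of the weight shape holds the channels, which is why one unpacks out_channels, _, *kernel_size and the other out_channels, *kernel_size, _. A standalone sketch of that unpacking in plain PyTorch (not part of the commit):

import torch

def out_channels_and_kernel(weight: torch.Tensor, layout: str):
    # Mirrors the shape unpacking in the two fake kernels above.
    if layout == "nchw":
        # (out_channels, in_channels // groups, kH, kW)
        out_channels, _, *kernel_size = weight.shape
    else:
        # (out_channels, kH, kW, in_channels // groups)
        out_channels, *kernel_size, _ = weight.shape
    return out_channels, kernel_size

w_nchw = torch.empty(16, 3, 5, 5)  # NCHW-layout conv weight
w_nhwc = torch.empty(16, 5, 5, 3)  # same filter bank, NHWC layout
assert out_channels_and_kernel(w_nchw, "nchw") == (16, [5, 5])
assert out_channels_and_kernel(w_nhwc, "nhwc") == (16, [5, 5])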

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 0 additions & 1 deletion

@@ -331,7 +331,6 @@ def get_args_and_kwargs_conv(
         "out_zero_point": quant_node.args[2],
         "out_multiplier": out_multiplier_,
         "out_shift": out_shift_,
-        "channel_last": False,
     }
     return args, kwargs
 
backends/cadence/aot/quantizer/patterns.py

Lines changed: 2 additions & 2 deletions

@@ -247,7 +247,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv.default
+        return torch.ops.cadence.quantized_conv_nchw.default
 
 
 class Conv2dPattern(QuantizationPattern):
@@ -286,7 +286,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv.default
+        return torch.ops.cadence.quantized_conv_nchw.default
 
 
 class LayerNormPattern(QuantizationPattern):
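With the pattern replacement now selecting the layout-specific op directly (and the fusion pass no longer passing channel_last), the layout is carried entirely by the op name. A hedged sketch of exercising the new NCHW per-tensor overload's shape inference under FakeTensorMode; the import path for the op registrations is an assumption based on this file's location, and the argument values are purely illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumed import path -- loads lib.define() and the fake kernels shown above.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

with FakeTensorMode():
    x = torch.empty(1, 3, 8, 8, dtype=torch.int8)  # NCHW input
    w = torch.empty(4, 3, 3, 3, dtype=torch.int8)  # (out_ch, in_ch, kH, kW)
    b = torch.empty(4, dtype=torch.int32)
    y = torch.ops.cadence.quantized_conv_nchw.per_tensor(
        x, w, b,
        [1, 1], [0, 0], [1, 1], 1,  # stride, padding, dilation, groups
        0, 0, 1.0,                  # input_zero_point, weight_zero_point, bias_scale
        1.0, 0, 1, 0,               # out_scale, out_zero_point, out_multiplier, out_shift
    )
    print(y.shape)  # torch.Size([1, 4, 6, 6]): (8 - 3 + 1) = 6 per spatial dim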
