Skip to content

Commit 5ff0208

Browse files
authored
Add support for strongly typed conv_nchw and conv_nhwc
Differential Revision: D80295124 Pull Request resolved: #13462
1 parent ed61348 commit 5ff0208

12 files changed

+2430
-29
lines changed

backends/cadence/aot/functions.yaml

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,21 @@
214214
- arg_meta: null
215215
kernel_name: impl::reference::quantized_linear_out
216216

217+
- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
218+
kernels:
219+
- arg_meta: null
220+
kernel_name: impl::reference::quantized_linear_per_tensor_out
221+
222+
- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
223+
kernels:
224+
- arg_meta: null
225+
kernel_name: impl::reference::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out
226+
227+
- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
228+
kernels:
229+
- arg_meta: null
230+
kernel_name: impl::reference::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out
231+
217232
- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
218233
kernels:
219234
- arg_meta: null
@@ -249,40 +264,45 @@
249264
- arg_meta: null
250265
kernel_name: impl::reference::quantized_matmul_asym8uxasym8u_asym8u_out
251266

252-
- func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
267+
- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
253268
kernels:
254269
- arg_meta: null
255-
kernel_name: impl::reference::quantized_linear_per_tensor_out
270+
kernel_name: impl::reference::im2row_out
256271

257-
- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
272+
- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
258273
kernels:
259274
- arg_meta: null
260-
kernel_name: impl::reference::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out
275+
kernel_name: impl::reference::im2row_per_tensor_out
261276

262-
- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
277+
- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
263278
kernels:
264279
- arg_meta: null
265-
kernel_name: impl::reference::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out
280+
kernel_name: impl::reference::quantized_conv_nchw_per_tensor_out
266281

267-
- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
282+
- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
268283
kernels:
269284
- arg_meta: null
270-
kernel_name: impl::reference::im2row_out
285+
kernel_name: impl::reference::quantized_conv_nhwc_per_tensor_out
271286

272-
- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
287+
- func: cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
273288
kernels:
274289
- arg_meta: null
275-
kernel_name: impl::reference::im2row_per_tensor_out
290+
kernel_name: impl::reference::quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out
276291

277-
- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
292+
- func: cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
278293
kernels:
279294
- arg_meta: null
280-
kernel_name: impl::reference::quantized_conv_nchw_per_tensor_out
295+
kernel_name: impl::reference::quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out
281296

282-
- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
297+
- func: cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
283298
kernels:
284299
- arg_meta: null
285-
kernel_name: impl::reference::quantized_conv_nhwc_per_tensor_out
300+
kernel_name: impl::reference::quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out
301+
302+
- func: cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
303+
kernels:
304+
- arg_meta: null
305+
kernel_name: impl::reference::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out
286306

287307
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
288308
kernels:

backends/cadence/aot/functions_hifi.yaml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,36 @@
300300
- arg_meta: null
301301
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_out
302302

303+
# HiFi kernel bindings for the per-tensor quantized convolution out-variants:
# the two generic ops (nchw / nhwc) followed by the dtype-specialized
# asym8sxsym8s_asym8s and asym8uxsym8u_asym8u ops for each layout.
# NOTE(review): the reference functions.yaml declares
# quantized_conv_nchw.per_tensor_out and quantized_conv_nhwc.per_tensor_out
# with a trailing `bool channel_last=False` argument that the two generic
# schemas below omit -- confirm which signature matches the op registration.
- func: cadence::quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
304+
kernels:
305+
- arg_meta: null
306+
kernel_name: cadence::impl::HiFi::quantized_conv_nchw_per_tensor_out
307+
308+
- func: cadence::quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
309+
kernels:
310+
- arg_meta: null
311+
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_per_tensor_out
312+
313+
- func: cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
314+
kernels:
315+
- arg_meta: null
316+
kernel_name: cadence::impl::HiFi::quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out
317+
318+
- func: cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
319+
kernels:
320+
- arg_meta: null
321+
kernel_name: cadence::impl::HiFi::quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out
322+
323+
- func: cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
324+
kernels:
325+
- arg_meta: null
326+
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out
327+
328+
- func: cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
329+
kernels:
330+
- arg_meta: null
331+
kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out
332+
303333
- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
304334
kernels:
305335
- arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,30 @@
120120
lib.define(
121121
"quantized_matmul_asym8sxasym8s_asym8s.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
122122
)
123+
# Schemas for the dtype-specialized per-tensor quantized convolutions.
# Every op below takes the same argument list; each one is registered as a
# functional overload (`per_tensor`) immediately followed by its out-variant
# (`per_tensor_out`).
_typed_conv_args = (
    "Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, int groups, int input_zero_point, int weight_zero_point, "
    "float bias_scale, float out_scale, int out_zero_point, int out_multiplier, "
    "int out_shift"
)
for _typed_conv_op in (
    "quantized_conv_nchw_asym8sxsym8s_asym8s",
    "quantized_conv_nchw_asym8uxsym8u_asym8u",
    "quantized_conv_nhwc_asym8sxsym8s_asym8s",
    "quantized_conv_nhwc_asym8uxsym8u_asym8u",
):
    lib.define(f"{_typed_conv_op}.per_tensor({_typed_conv_args}) -> (Tensor Z)")
    lib.define(
        f"{_typed_conv_op}.per_tensor_out({_typed_conv_args}, "
        "*, Tensor(a!) out) -> Tensor(a!)"
    )
# Drop the loop scaffolding so no extra names remain at module level.
del _typed_conv_args, _typed_conv_op
123147
lib.define(
124148
"quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
125149
)
@@ -719,6 +743,182 @@ def quantized_conv_nhwc_per_tensor_meta(
719743
return input.new_empty(output_size, dtype=input.dtype)
720744

721745

746+
@register_fake("cadence::quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Fake (shape-only) kernel for ``quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor``.

    No convolution is computed here: the function only derives the output
    shape and returns an uninitialized tensor of that shape.

    Args:
        input: Activation tensor; must have 3, 4 or 5 dimensions.
        weight: Weight tensor, unpacked as ``(out_channels, _, *kernel_size)``.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        bias, groups, in_zero_point, weight_zero_point, bias_scale,
        output_scale, output_zero_point, out_multiplier, out_shift:
            Accepted to mirror the operator schema; none of them is read
            when computing the output shape.

    Returns:
        ``input.new_empty(output_size, dtype=input.dtype)``.

    Raises:
        AssertionError: If ``input`` has fewer than 3 or more than 5 dims.
    """
    out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # The input tensor must have at least 3 and at most 5 dimensions.
    assert len(in_size) > 2, f"expected input rank >= 3, got {len(in_size)}"
    assert len(in_size) < 6, f"expected input rank <= 5, got {len(in_size)}"

    # The trailing False is forwarded as the helpers' last positional flag
    # (presumably `channel_last` -- confirm against the helper signatures).
    if len(in_size) == 3:
        # The rank-3 path reads element [1] of stride/padding/dilation, so
        # those sequences need at least two entries even for a 1D conv.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            False,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, False
        )

    return input.new_empty(output_size, dtype=input.dtype)
788+
789+
790+
@register_fake("cadence::quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Fake (shape-only) kernel for ``quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor``.

    No convolution is computed here: the function only derives the output
    shape and returns an uninitialized tensor of that shape.

    Args:
        input: Activation tensor; must have 3, 4 or 5 dimensions.
        weight: Weight tensor, unpacked as ``(out_channels, _, *kernel_size)``.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        bias, groups, in_zero_point, weight_zero_point, bias_scale,
        output_scale, output_zero_point, out_multiplier, out_shift:
            Accepted to mirror the operator schema; none of them is read
            when computing the output shape.

    Returns:
        ``input.new_empty(output_size, dtype=input.dtype)``.

    Raises:
        AssertionError: If ``input`` has fewer than 3 or more than 5 dims.
    """
    out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # The input tensor must have at least 3 and at most 5 dimensions.
    assert len(in_size) > 2, f"expected input rank >= 3, got {len(in_size)}"
    assert len(in_size) < 6, f"expected input rank <= 5, got {len(in_size)}"

    # The trailing False is forwarded as the helpers' last positional flag
    # (presumably `channel_last` -- confirm against the helper signatures).
    if len(in_size) == 3:
        # The rank-3 path reads element [1] of stride/padding/dilation, so
        # those sequences need at least two entries even for a 1D conv.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            False,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, False
        )

    return input.new_empty(output_size, dtype=input.dtype)
832+
833+
834+
@register_fake("cadence::quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Fake (shape-only) kernel for ``quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor``.

    No convolution is computed here: the function only derives the output
    shape and returns an uninitialized tensor of that shape.

    Args:
        input: Activation tensor; must have 3, 4 or 5 dimensions.
        weight: Weight tensor, unpacked as ``(out_channels, *kernel_size, _)``.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        bias, groups, in_zero_point, weight_zero_point, bias_scale,
        output_scale, output_zero_point, out_multiplier, out_shift:
            Accepted to mirror the operator schema; none of them is read
            when computing the output shape.

    Returns:
        ``input.new_empty(output_size, dtype=input.dtype)``.

    Raises:
        AssertionError: If ``input`` has fewer than 3 or more than 5 dims.
    """
    out_channels, *kernel_size, _ = weight.shape

    in_size = input.shape
    # The input tensor must have at least 3 and at most 5 dimensions.
    assert len(in_size) > 2, f"expected input rank >= 3, got {len(in_size)}"
    assert len(in_size) < 6, f"expected input rank <= 5, got {len(in_size)}"

    # The trailing True is forwarded as the helpers' last positional flag
    # (presumably `channel_last` -- confirm against the helper signatures).
    if len(in_size) == 3:
        # The rank-3 path reads element [1] of stride/padding/dilation, so
        # those sequences need at least two entries even for a 1D conv.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            True,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, True
        )

    return input.new_empty(output_size, dtype=input.dtype)
876+
877+
878+
@register_fake("cadence::quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int, ...],
    padding: Tuple[int, ...],
    dilation: Tuple[int, ...],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Fake (shape-only) kernel for ``quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor``.

    No convolution is computed here: the function only derives the output
    shape and returns an uninitialized tensor of that shape.

    Args:
        input: Activation tensor; must have 3, 4 or 5 dimensions.
        weight: Weight tensor, unpacked as ``(out_channels, *kernel_size, _)``.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        bias, groups, in_zero_point, weight_zero_point, bias_scale,
        output_scale, output_zero_point, out_multiplier, out_shift:
            Accepted to mirror the operator schema; none of them is read
            when computing the output shape.

    Returns:
        ``input.new_empty(output_size, dtype=input.dtype)``.

    Raises:
        AssertionError: If ``input`` has fewer than 3 or more than 5 dims.
    """
    out_channels, *kernel_size, _ = weight.shape

    in_size = input.shape
    # The input tensor must have at least 3 and at most 5 dimensions.
    assert len(in_size) > 2, f"expected input rank >= 3, got {len(in_size)}"
    assert len(in_size) < 6, f"expected input rank <= 5, got {len(in_size)}"

    # The trailing True is forwarded as the helpers' last positional flag
    # (presumably `channel_last` -- confirm against the helper signatures).
    if len(in_size) == 3:
        # The rank-3 path reads element [1] of stride/padding/dilation, so
        # those sequences need at least two entries even for a 1D conv.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            True,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, True
        )

    return input.new_empty(output_size, dtype=input.dtype)
920+
921+
722922
@register_fake("cadence::quantized_layer_norm")
723923
def quantized_layer_norm_meta(
724924
input: torch.Tensor,

0 commit comments

Comments
 (0)