Commit 201beda

Split on dilation for strongly typed convs
Differential Revision: D80574314
Pull Request resolved: #13532
1 parent ea0fff3, commit 201beda

16 files changed: +1315, -759 lines

backends/cadence/aot/functions.yaml

Lines changed: 20 additions & 0 deletions

@@ -304,6 +304,26 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
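
All four new schemas follow the .per_tensor_out convention: the kernel writes into a caller-provided Tensor(a!) out argument and returns that same aliased tensor. A minimal sketch of that contract (the stub below is illustrative only, not the reference kernel):

    import torch

    def per_tensor_out_sketch(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # A real kernel would write the quantized convolution result into out;
        # this stub only illustrates the Tensor(a!) aliasing contract.
        out.copy_(input)
        return out

    x = torch.zeros(4, dtype=torch.int8)
    buf = torch.empty(4, dtype=torch.int8)
    assert per_tensor_out_sketch(x, buf) is buf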

backends/cadence/aot/functions_hifi.yaml

Lines changed: 20 additions & 0 deletions

@@ -330,6 +330,26 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
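
The HiFi entries mirror the reference ones above; only the kernel_name namespace differs. The variant names themselves are assembled from the layout, an optional dilated marker, and a dtype suffix. A sketch of that naming scheme (the helper is illustrative, not part of the commit):

    def variant_name(layout: str, dilated: bool, type_suffix: str) -> str:
        # layout is "nchw" or "nhwc"; type_suffix encodes input x weight -> output
        # dtypes, e.g. "asym8sxsym8s_asym8s" for int8 activations and weights.
        parts = ["quantized_conv", layout]
        if dilated:
            parts.append("dilated")
        parts.append(type_suffix)
        return "_".join(parts)

    assert (
        variant_name("nchw", True, "asym8sxsym8s_asym8s")
        == "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s"
    )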

backends/cadence/aot/ops_registrations.py

Lines changed: 200 additions & 0 deletions

@@ -144,6 +144,30 @@
 lib.define(
     "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
 )

@@ -919,6 +943,182 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            False,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            False,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, *kernel_size, _ = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            True,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, *kernel_size, _ = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            True,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
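
The four fake kernels above only compute output shapes. Assuming get_conv2d_output_size follows the standard convolution arithmetic (the helper below is a stand-in for illustration, not the actual implementation), the per-dimension rule is:

    def conv_out_dim(in_dim: int, pad: int, dilation: int, kernel: int, stride: int) -> int:
        # Standard convolution output-size formula; dilation widens the
        # effective kernel to dilation * (kernel - 1) + 1.
        return (in_dim + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    # Example matching the dilated tests below: an 8x8 input with a 3x3 kernel,
    # dilation 2, stride 1, and no padding yields a 4x4 output.
    assert conv_out_dim(8, 0, 2, 3, 1) == 4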

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 104 additions & 0 deletions

@@ -341,3 +341,107 @@ def test_uint8_dispatch_quantized_conv_nhwc(self) -> None:
             ),
             1,
         )
+
+    def test_int8_dispatch_quantized_conv_nchw_dilated(self) -> None:
+        """Test that int8 x int8 inputs with dilation dispatch to the dilated_asym8sxsym8s_asym8s variant of quantized_conv_nchw"""
+        x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nchw_dilated(self) -> None:
+        """Test that uint8 x uint8 inputs with dilation dispatch to the dilated_asym8uxsym8u_asym8u variant of quantized_conv_nchw"""
+        x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
+
+    def test_int8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
+        """Test that int8 x int8 inputs with dilation dispatch to the dilated_asym8sxsym8s_asym8s variant of quantized_conv_nhwc"""
+        x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8)
+        w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
+        """Test that uint8 x uint8 inputs with dilation dispatch to the dilated_asym8uxsym8u_asym8u variant of quantized_conv_nhwc"""
+        x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8)
+        w = torch.randint(0, 255, (16, 3, 3, 3), dtype=torch.uint8)
+        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
+        gm = single_op_builder(
+            placeholders=(x, w, b),
+            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+            args=(x, w, b, [1, 1], [0, 0], [2, 2], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
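
As a cross-check of the shapes these tests use (a sketch relying on the standard conv shape rule, not the Cadence kernels): an input of (1, 3, 8, 8) with 16 filters of size 3x3, dilation 2, stride 1, and no padding produces a (1, 16, 4, 4) output.

    import torch

    x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
    w = torch.randint(-128, 127, (16, 3, 3, 3), dtype=torch.int8)
    # conv2d has no int8 CPU path, so cast to float purely to verify the shape.
    out = torch.nn.functional.conv2d(
        x.float(), w.float(), stride=1, padding=0, dilation=2
    )
    assert out.shape == (1, 16, 4, 4)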

backends/cadence/aot/type_dispatch.py

Lines changed: 14 additions & 1 deletion

@@ -112,7 +112,20 @@ def call_operator(
             raise RuntimeError(f"Unsupported input types for {op}: {dtype_key}")
 
         type_suffix = config.type_dispatch_suffixes[dtype_key]
-        typed_op_name = f"{config.base_name}_{type_suffix}"
+        base_name = config.base_name
+
+        if op in [
+            exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
+            exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
+        ]:
+            dilation = args[5]
+            # pyre-ignore[16]: None has no attribute '__iter__'.
+            is_dilated = any(d > 1 for d in dilation)
+
+            if is_dilated:
+                type_suffix = f"dilated_{type_suffix}"
+
+        typed_op_name = f"{base_name}_{type_suffix}"
 
         typed_op = getattr(
             getattr(exir_ops.edge.cadence, typed_op_name), config.variant
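
The rule the pass applies is easy to state in isolation: a conv op is routed to the dilated variant only when some dilation factor exceeds 1. A standalone sketch of that decision (names here are illustrative, not the pass's real API):

    def pick_suffix(dilation: list, type_suffix: str) -> str:
        # Mirrors the dispatch condition added above: any dilation > 1
        # selects the "dilated_" prefixed kernel variant.
        if any(d > 1 for d in dilation):
            return f"dilated_{type_suffix}"
        return type_suffix

    assert pick_suffix([1, 1], "asym8sxsym8s_asym8s") == "asym8sxsym8s_asym8s"
    assert pick_suffix([2, 2], "asym8sxsym8s_asym8s") == "dilated_asym8sxsym8s_asym8s"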
