
Commit e4ddf69

Split on depthwise for strongly typed convs

Differential Revision: D80636815
Pull Request resolved: #13562
Parent: 4598fdb

16 files changed: +1322 −292 lines
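
The commit title refers to routing depthwise convolutions (where each input channel is convolved by its own filter group) to dedicated strongly typed kernels instead of the generic per-tensor variants. A minimal sketch of the predicate such a split could key on, assuming the usual PyTorch convention that a convolution is depthwise when groups equals the input channel count; the helper name and layout handling below are illustrative, not code from this commit:

import torch

def is_depthwise(input_tensor: torch.Tensor, groups: int, channel_last: bool) -> bool:
    # Depthwise convolution: one filter group per input channel.
    # NCHW keeps channels at dim 1; NHWC keeps them in the last dim.
    in_channels = input_tensor.shape[-1] if channel_last else input_tensor.shape[1]
    return groups == in_channels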

backends/cadence/aot/functions.yaml

Lines changed: 20 additions & 0 deletions
@@ -339,6 +339,26 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
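
Each YAML entry binds an operator schema to a concrete kernel symbol, and the mapping visible above is mechanical: strip the cadence:: namespace, replace the overload dot with an underscore, and prepend the backend's implementation namespace. A sketch of that convention (illustrative only; the actual binding is performed by the ExecuTorch kernel-registration codegen, not by this function):

def kernel_name_for(schema_name: str, impl_ns: str = "impl::reference") -> str:
    # "cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out"
    #   -> "impl::reference::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out"
    op = schema_name.split("::")[-1].replace(".", "_")
    return f"{impl_ns}::{op}"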

backends/cadence/aot/functions_hifi.yaml

Lines changed: 20 additions & 0 deletions
@@ -350,6 +350,26 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
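
The HiFi entries mirror the reference ones; only the implementation namespace changes (cadence::impl::HiFi instead of impl::reference). The strongly typed suffixes appear to encode the quantization of each operand, reading asym8sxsym8s_asym8s as activation x weight _ output, with asym8s/asym8u being asymmetric signed/unsigned 8-bit (a nonzero zero point is allowed) and sym8s/sym8u symmetric. That reading is inferred from the naming pattern, not stated in the commit; a small decoder sketch under that assumption:

def decode_type_suffix(suffix: str) -> dict:
    # "asym8sxsym8s_asym8s" -> quantization kinds for activation, weight, output.
    # Inferred convention: asym = asymmetric (has zero point), sym = symmetric;
    # 8s = signed 8-bit, 8u = unsigned 8-bit.
    act_weight, out = suffix.split("_")
    act, weight = act_weight.split("x")
    return {"activation": act, "weight": weight, "output": out}

assert decode_type_suffix("asym8sxsym8s_asym8s") == {
    "activation": "asym8s", "weight": "sym8s", "output": "asym8s"
}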

backends/cadence/aot/ops_registrations.py

Lines changed: 200 additions & 0 deletions
@@ -168,6 +168,30 @@
 lib.define(
     "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
 )
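
Each functional .per_tensor variant defined above is paired, in the next hunk, with a fake (meta) kernel registered via register_fake, so the exporter can propagate output shapes and dtypes without running the real kernel. A minimal self-contained sketch of that pattern, assuming torch.library.register_fake (available in recent PyTorch); the namespace and op here are hypothetical, while the real file reuses its existing library object:

import torch
from torch.library import Library, register_fake

lib = Library("example_ns", "DEF")  # hypothetical namespace for illustration
lib.define("my_op(Tensor x) -> Tensor")

@register_fake("example_ns::my_op")
def my_op_meta(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation only; no real computation happens here.
    return x.new_empty(x.shape, dtype=x.dtype)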
@@ -1165,6 +1189,182 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            False,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            False,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, False
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, *kernel_size, _ = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            True,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    out_channels, *kernel_size, _ = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 and at most 5 dimensions
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            True,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, True
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,
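
These meta kernels only need output shapes. Assuming get_conv1d_output_size and get_conv2d_output_size implement the standard convolution arithmetic (which is what their call signatures above suggest; the helper below is an illustrative re-derivation, not the code in this file), each spatial output dimension follows the usual formula:

def conv_out_dim(in_dim: int, pad: int, dilation: int, kernel: int, stride: int) -> int:
    # Standard convolution output-size arithmetic:
    # out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    return (in_dim + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# Example: a depthwise 3x3 kernel, stride 2, padding 1, dilation 1
# over a 16x16 feature map yields an 8x8 output.
assert conv_out_dim(16, 1, 1, 3, 2) == 8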
