
Commit 651b357

Cadence: Support quantized_w8a32_conv
Differential Revision: D84658444
Pull Request resolved: pytorch#15137
1 parent 93f19db commit 651b357

File tree

3 files changed (+242 lines, -2 lines)

backends/cadence/aot/ops_registrations.py

Lines changed: 4 additions & 2 deletions
@@ -67,7 +67,6 @@ def _validate_ref_impl_exists() -> None:
     "cadence::dequantize_per_tensor_asym16u",
     "cadence::linalg_vector_norm",
     "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
-    "cadence::quantized_w8a32_conv",
     "cadence::quantize_per_tensor_asym32s",
     "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
     "cadence::linalg_svd",

@@ -2753,7 +2752,10 @@ def quantized_w8a32_conv_meta(
     # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
     assert len(src.shape) == 3

-    kernel_size, out_channels, in_channels = weight.shape
+    out_channels, in_channels, kernel_size = weight.shape
+    assert kernel_size == 3
+    assert (out_channels % 4) == 0
+    assert (in_channels % 4) == 0
     assert in_channels == src.shape[-1]

     # Compute the output tensor size
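With the weight now unpacked as [out_channels, in_channels, kernel_size], the meta function can size the output the way the comment above describes: a stride-1, unpadded 1D convolution yields out_length = in_length - kernel_size + 1. A minimal sketch of that arithmetic (a hypothetical standalone helper, assuming the [batch, in_channels, in_length] layout the reference implementation documents, not the registered meta kernel itself):

import torch

def w8a32_conv_output_shape(src: torch.Tensor, weight: torch.Tensor) -> torch.Size:
    # src: [batch, in_channels, in_length]; weight: [out_channels, in_channels, kernel_size]
    batch, _, in_length = src.shape
    out_channels, _, kernel_size = weight.shape
    # Stride-1, no padding: each output position consumes kernel_size inputs.
    return torch.Size([batch, out_channels, in_length - kernel_size + 1])

# A 1x4x5 input with a size-3 kernel yields a 1x4x3 output.
assert w8a32_conv_output_shape(torch.zeros(1, 4, 5), torch.zeros(4, 4, 3)) == torch.Size([1, 4, 3])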

backends/cadence/aot/ref_implementations.py

Lines changed: 42 additions & 0 deletions
@@ -703,6 +703,48 @@ def quantized_conv2d_nchw_per_tensor(
     )


+@impl_tracked(m, "quantized_w8a32_conv")
+def quantized_w8a32_conv(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+
+    if len(weight.shape) != 3:
+        raise ValueError("Weight tensor must be 3D")
+
+    out_channels, in_channels, kernel_size = weight.shape
+    if kernel_size != 3:
+        raise ValueError("Kernel size must be 3")
+    if (out_channels % 4) != 0:
+        raise ValueError("Out channels must be a multiple of 4")
+    if (in_channels % 4) != 0:
+        raise ValueError("In channels must be a multiple of 4")
+
+    # src comes in shape [batch, in_channel, in_length]
+    # weight comes in shape [out_ch, in_ch, kernel_dim]
+    # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
+    # Dequantize weight using scale
+    dequant_weight = weight.float() * w_scale
+
+    # Dequantize bias using scale
+    dequant_bias = bias.float() * b_scale
+
+    # Perform 1D convolution
+    # src: [batch, in_channel, in_length]
+    # weight: [out_ch, in_ch, kernel_dim]
+    # bias: [out_ch]
+    output = torch.nn.functional.conv1d(
+        src.float(),
+        dequant_weight,
+        dequant_bias,
+    )
+
+    return output
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
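The reference implementation is deliberately simple: dequantize the int8 weights and bias with their scales, then run a float32 conv1d ("w8a32": 8-bit weights, 32-bit activations). A self-contained sketch of the same semantics in stock PyTorch (the helper name and sample values are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

def w8a32_conv_reference(src, weight, w_scale, bias, b_scale):
    # Mirrors quantized_w8a32_conv above: only the weights and bias are
    # quantized; activations stay in float32.
    dequant_weight = weight.float() * w_scale
    dequant_bias = bias.float() * b_scale
    return F.conv1d(src.float(), dequant_weight, dequant_bias)

src = torch.randn(1, 4, 5)  # [batch, in_channels, in_length]
weight = torch.randint(-128, 128, (4, 4, 3), dtype=torch.int8)  # [out_ch, in_ch, kernel]
out = w8a32_conv_reference(src, weight, 0.1, torch.zeros(4), 1.0)
print(out.shape)  # torch.Size([1, 4, 3])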

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 196 additions & 0 deletions
@@ -1040,6 +1040,202 @@ def test_quantized_conv_per_tensor(
             f"Output values don't match expected. Got {output}, expected {expected_output}",
         )

+    @expand(
+        [
+            (
+                "basic_int8_weights",
+                torch.tensor(
+                    [
+                        [
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # src: 1x4x5
+                torch.tensor(
+                    [
+                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
+                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
+                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
+                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
+                    ],
+                    dtype=torch.int8,
+                ),  # weight: 4x4x3
+                0.1,  # w_scale
+                torch.tensor([1, 1, 1, 1], dtype=torch.int8),  # bias: 4
+                0.2,  # b_scale
+                torch.tensor(
+                    [
+                        [
+                            [2.2, 3.0, 3.8],
+                            [2.2, 3.0, 3.8],
+                            [2.2, 3.0, 3.8],
+                            [2.2, 3.0, 3.8],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # expected: conv1d result
+            ),
+            (
+                "batch_size_2",
+                torch.tensor(
+                    [
+                        [
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                        ],
+                        [
+                            [2.0, 3.0, 4.0, 5.0, 6.0],
+                            [2.0, 3.0, 4.0, 5.0, 6.0],
+                            [2.0, 3.0, 4.0, 5.0, 6.0],
+                            [2.0, 3.0, 4.0, 5.0, 6.0],
+                        ],
+                    ],
+                    dtype=torch.float32,
+                ),  # src: 2x4x5
+                torch.tensor(
+                    [
+                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
+                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
+                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
+                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
+                    ],
+                    dtype=torch.int8,
+                ),  # weight: 4x4x3
+                1.0,  # w_scale
+                torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
+                1.0,  # b_scale
+                torch.tensor(
+                    [
+                        [
+                            [24.0, 36.0, 48.0],
+                            [24.0, 36.0, 48.0],
+                            [24.0, 36.0, 48.0],
+                            [24.0, 36.0, 48.0],
+                        ],
+                        [
+                            [36.0, 48.0, 60.0],
+                            [36.0, 48.0, 60.0],
+                            [36.0, 48.0, 60.0],
+                            [36.0, 48.0, 60.0],
+                        ],
+                    ],
+                    dtype=torch.float32,
+                ),  # expected
+            ),
+            (
+                "zero_weights_bias",
+                torch.tensor(
+                    [
+                        [
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                            [1.0, 2.0, 3.0, 4.0, 5.0],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # src: 1x4x5
+                torch.tensor(
+                    [
+                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+                    ],
+                    dtype=torch.int8,
+                ),  # weight: 4x4x3
+                0.1,  # w_scale
+                torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
+                1.0,  # b_scale
+                torch.tensor(
+                    [
+                        [
+                            [0.0, 0.0, 0.0],
+                            [0.0, 0.0, 0.0],
+                            [0.0, 0.0, 0.0],
+                            [0.0, 0.0, 0.0],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # expected
+            ),
+            (
+                "negative_weights",
+                torch.tensor(
+                    [
+                        [
+                            [2.0, 4.0, 6.0, 8.0, 10.0],
+                            [2.0, 4.0, 6.0, 8.0, 10.0],
+                            [2.0, 4.0, 6.0, 8.0, 10.0],
+                            [2.0, 4.0, 6.0, 8.0, 10.0],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # src: 1x4x5
+                torch.tensor(
+                    [
+                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
+                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
+                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
+                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
+                    ],
+                    dtype=torch.int8,
+                ),  # weight: 4x4x3
+                0.5,  # w_scale
+                torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
+                1.0,  # b_scale
+                torch.tensor(
+                    [
+                        [
+                            [-14.0, -26.0, -38.0],
+                            [-14.0, -26.0, -38.0],
+                            [-14.0, -26.0, -38.0],
+                            [-14.0, -26.0, -38.0],
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_conv(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_w8a32_conv(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
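As a sanity check on the basic_int8_weights case: the dequantized kernel taps are [0.1, -0.1, 0.2], the bias dequantizes to 1 * 0.2 = 0.2, and the first output position sums four identical channel contributions of 1*0.1 + 2*(-0.1) + 3*0.2 = 0.5, giving 4*0.5 + 0.2 = 2.2. The snippet below (illustrative, not part of the test file) reproduces the expected row with stock conv1d:

import torch
import torch.nn.functional as F

# basic_int8_weights, reduced to the arithmetic
src = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]] * 4).unsqueeze(0)  # 1x4x5
weight = torch.tensor([[[1, -1, 2]] * 4] * 4, dtype=torch.int8)   # 4x4x3
bias = torch.ones(4) * 0.2                                        # bias [1,1,1,1] * b_scale 0.2
out = F.conv1d(src, weight.float() * 0.1, bias)                   # w_scale 0.1
print(out[0, 0])  # tensor([2.2000, 3.0000, 3.8000])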
