@@ -1040,6 +1040,202 @@ def test_quantized_conv_per_tensor(
10401040 f"Output values don't match expected. Got { output } , expected { expected_output } " ,
10411041 )
10421042
@expand(
    [
        # Each case: (name, src, weight, w_scale, bias, b_scale, expected).
        # Rows that are identical across channels are built with list
        # replication so the intended tensor shape is obvious at a glance.
        # Expected values are consistent with a stride-1, unpadded conv1d over
        # (weight * w_scale) plus (bias * b_scale).
        (
            "basic_int8_weights",
            torch.tensor(
                [[[1.0, 2.0, 3.0, 4.0, 5.0]] * 4],
                dtype=torch.float32,
            ),  # src: 1x4x5
            torch.tensor(
                [[[1, -1, 2]] * 4] * 4,
                dtype=torch.int8,
            ),  # weight: 4x4x3
            0.1,  # w_scale
            torch.tensor([1, 1, 1, 1], dtype=torch.int8),  # bias: 4
            0.2,  # b_scale
            torch.tensor(
                [[[2.2, 3.0, 3.8]] * 4],
                dtype=torch.float32,
            ),  # expected: conv1d result
        ),
        (
            "batch_size_2",
            torch.tensor(
                [
                    [[1.0, 2.0, 3.0, 4.0, 5.0]] * 4,
                    [[2.0, 3.0, 4.0, 5.0, 6.0]] * 4,
                ],
                dtype=torch.float32,
            ),  # src: 2x4x5
            torch.tensor(
                [[[1, 1, 1]] * 4] * 4,
                dtype=torch.int8,
            ),  # weight: 4x4x3
            1.0,  # w_scale
            torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
            1.0,  # b_scale
            torch.tensor(
                [
                    [[24.0, 36.0, 48.0]] * 4,
                    [[36.0, 48.0, 60.0]] * 4,
                ],
                dtype=torch.float32,
            ),  # expected
        ),
        (
            "zero_weights_bias",
            torch.tensor(
                [[[1.0, 2.0, 3.0, 4.0, 5.0]] * 4],
                dtype=torch.float32,
            ),  # src: 1x4x5
            torch.tensor(
                [[[0, 0, 0]] * 4] * 4,
                dtype=torch.int8,
            ),  # weight: 4x4x3
            0.1,  # w_scale
            torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
            1.0,  # b_scale
            torch.tensor(
                [[[0.0, 0.0, 0.0]] * 4],
                dtype=torch.float32,
            ),  # expected
        ),
        (
            "negative_weights",
            torch.tensor(
                [[[2.0, 4.0, 6.0, 8.0, 10.0]] * 4],
                dtype=torch.float32,
            ),  # src: 1x4x5
            torch.tensor(
                [[[-2, -1, 0]] * 4] * 4,
                dtype=torch.int8,
            ),  # weight: 4x4x3
            0.5,  # w_scale
            # int8 like every other case: the bias is passed together with
            # b_scale, and 2 * 1.0 yields the same effective bias of 2.0.
            torch.tensor([2, 2, 2, 2], dtype=torch.int8),  # bias: 4
            1.0,  # b_scale
            torch.tensor(
                [[[-14.0, -26.0, -38.0]] * 4],
                dtype=torch.float32,
            ),  # expected
        ),
    ]
)
def test_quantized_w8a32_conv(
    self,
    name: str,
    src: torch.Tensor,
    weight: torch.Tensor,
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
    expected_output: torch.Tensor,
) -> None:
    # Runs the cadence quantized_w8a32_conv op on one parameterized case and
    # checks the result's dtype, shape and values against expected_output.
    output = torch.ops.cadence.quantized_w8a32_conv(
        src, weight, w_scale, bias, b_scale
    )

    # Verify output properties
    self.assertEqual(
        output.dtype,
        torch.float32,
        f"Output dtype should be float32 in {name}",
    )
    self.assertEqual(
        output.shape,
        expected_output.shape,
        f"Output shape should match expected shape in {name}",
    )

    # Verify output matches expected values (tolerances absorb float32
    # rounding from the fractional scales such as 0.1 and 0.2)
    self.assertTrue(
        torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
        f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
    )
1238+
10431239 @expand (
10441240 [
10451241 # Test case 1: Basic int8 case with negative scale
0 commit comments