@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
                     dtype=torch.int8,
                 ),  # weight: 4x4x3
                 0.5,  # w_scale
-                torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
+                torch.tensor([2, 2, 2, 2], dtype=torch.int8),  # bias: 4
                 1.0,  # b_scale
                 torch.tensor(
                     [
@@ -1214,6 +1214,12 @@ def test_quantized_w8a32_conv(
         b_scale: float,
         expected_output: torch.Tensor,
     ) -> None:
+
+        # This op takes in channels last src
+        src = src.permute(0, 2, 1)
+
+        # This op takes in LNC format for weights
+        weight = weight.permute(2, 0, 1)
         output = torch.ops.cadence.quantized_w8a32_conv(
             src, weight, w_scale, bias, b_scale
         )
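The two permutes added here convert the test inputs from PyTorch's default conv layouts into the layouts the op consumes. A minimal sketch of the layout change, assuming src is (batch, in_channels, length) and weight is (out_channels, in_channels, kernel_length) as suggested by the `# weight: 4x4x3` comment earlier in the file (the concrete shapes below are illustrative only):

```python
import torch

# Hypothetical shapes for illustration only.
src = torch.randn(1, 4, 8)     # (batch, in_channels, length), i.e. NCL
weight = torch.randn(4, 4, 3)  # (out_channels, in_channels, kernel_length)

# Channels-last activations: NCL -> NLC, i.e. (batch, length, in_channels).
src_nlc = src.permute(0, 2, 1)

# LNC weights: (out_ch, in_ch, kernel) -> (kernel, out_ch, in_ch).
weight_lnc = weight.permute(2, 0, 1)

print(src_nlc.shape)     # torch.Size([1, 8, 4])
print(weight_lnc.shape)  # torch.Size([3, 4, 4])
```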
@@ -1236,6 +1242,95 @@ def test_quantized_w8a32_conv(
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
 
+    @expand(
+        [
+            (
+                "multi_input_features",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # src: 1x3
+                torch.tensor([[2, 1], [1, 2], [1, 1]], dtype=torch.int8),  # weight: 3x2
+                0.5,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.5, 5.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "batch_size_2",
+                torch.tensor(
+                    [[[1.0, 2.0]], [[3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 2x1x2
+                torch.tensor([[1, 2], [1, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 0], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 0.0]], [[7.0, 2.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            (
+                "shape_assertion_error",
+                torch.tensor(
+                    [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 1x2x2
+                torch.tensor([[1, 2], [1, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[2.0, 4.0]], dtype=torch.float32),  # src: 1x2
+                torch.tensor([[-2, -3], [-1, -2]], dtype=torch.int8),  # weight: 2x2
+                0.5,  # w_scale
+                torch.tensor([2, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[-2.0, -6.0]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_linear(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        if name == "shape_assertion_error":
+            with self.assertRaisesRegex(
+                AssertionError, "Only supporting vector-matrix multiplication"
+            ):
+                torch.ops.cadence.quantized_w8a32_linear(
+                    src, weight, w_scale, bias, b_scale
+                )
+            return
+
+        output = torch.ops.cadence.quantized_w8a32_linear(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
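Judging from the expected values in these cases, `quantized_w8a32_linear` behaves like a float linear over dequantized int8 weights and bias, with the weight stored as (in_features, out_features). A minimal reference sketch of that behavior (not the actual Cadence kernel; the helper name is made up for illustration, and the vector-matrix shape restriction exercised by "shape_assertion_error" is not enforced here):

```python
import torch


def w8a32_linear_ref(
    src: torch.Tensor,     # float32 activations, last dim = in_features
    weight: torch.Tensor,  # int8, (in_features, out_features) as in the cases above
    w_scale: float,
    bias: torch.Tensor,    # int8, (out_features,)
    b_scale: float,
) -> torch.Tensor:
    # Dequantize the weight and bias, then do a plain float matmul.
    return src @ (weight.float() * w_scale) + bias.float() * b_scale


# Reproduces the "multi_input_features" expectation: [[3.5, 5.0]]
src = torch.tensor([[1.0, 2.0, 3.0]])
weight = torch.tensor([[2, 1], [1, 2], [1, 1]], dtype=torch.int8)
bias = torch.tensor([0, 1], dtype=torch.int8)
print(w8a32_linear_ref(src, weight, 0.5, bias, 1.0))  # tensor([[3.5000, 5.0000]])
```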