@@ -177,6 +177,8 @@ def test_quantized_add(
                     0,  # out_zero_point
                     torch.tensor([[-2]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -200,6 +202,8 @@ def test_quantized_add(
                     0,  # out_zero_point
                     torch.tensor([[-10, -30]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -225,6 +229,8 @@ def test_quantized_add(
                         [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
                     ),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -248,6 +254,8 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-15, 25]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in (
                     (False, torch.int8),
@@ -271,6 +279,8 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-23, 17]], dtype=dtype),  # expected_output
                     False,
+                    False,
+                    False,
                 )
                 for dtype in (torch.int8, torch.uint8)
             ],
@@ -292,9 +302,34 @@ def test_quantized_add(
                     1,  # out_zero_point
                     torch.tensor([[-7, 13]], dtype=dtype),  # expected_output
                     per_tensor,
+                    False,
+                    False,
                 )
                 for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
             ],
+            *[
+                (
+                    torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                    torch.Size(
+                        [2, 2]
+                    ),  # weight_shape: 2 output features, 2 input features
+                    2,  # in_zero_point
+                    torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                    torch.tensor(
+                        [268435456], dtype=torch.int32
+                    ),  # out_multiplier (0.125 * 2^31)
+                    torch.tensor(
+                        [1], dtype=torch.int64
+                    ),  # out_shift (shift=1, doubles the scale)
+                    1,  # out_zero_point
+                    torch.tensor([[-7, 17]], dtype=dtype),  # expected_output
+                    per_tensor,
+                    matmul,
+                    transposed_matmul,
+                )
+                for (matmul, transposed_matmul) in ((True, False), (True, True))
+                for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
+            ],
         ]
     )
     def test_quantized_linear(
@@ -308,7 +343,12 @@ def test_quantized_linear(
         out_zero_point: int,
         expected_output: torch.Tensor,
         per_tensor: bool,
+        matmul: bool,
+        transposed_matmul: bool,
     ) -> None:
+        if not per_tensor and matmul:
+            self.skipTest("Only per_tensor supported for matmul")
+
         src = (
             torch.arange(np.prod(src_shape))
             .reshape(src_shape)
@@ -319,7 +359,9 @@ def test_quantized_linear(
             .reshape(weight_shape)
             .to(expected_output.dtype)
         )
-        bias = torch.arange(weight_shape[0]).to(torch.int32)
+        if matmul and not transposed_matmul:
+            weight = weight.T
+
         if per_tensor:
             weight_zero_point = weight_zero_point[0]
             out_multiplier = out_multiplier[0]
@@ -328,38 +370,75 @@ def test_quantized_linear(
         if per_tensor:
             match expected_output.dtype:
                 case torch.int8:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (
+                            # Doesn't have a per-tensor name, but it is per-tensor
+                            torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s,
+                        )
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
+                        )
                 case torch.uint8:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u,
+                        )
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
+                        )
                 case _:
-                    linear_ops = (
-                        torch.ops.cadence.quantized_linear.per_tensor,
-                        torch.ops.cadence.quantized_fully_connected.per_tensor,
-                    )
+                    if matmul:
+                        linear_ops = (torch.ops.cadence.quantized_matmul,)
+                    else:
+                        linear_ops = (
+                            torch.ops.cadence.quantized_linear.per_tensor,
+                            torch.ops.cadence.quantized_fully_connected.per_tensor,
+                        )
         else:
             linear_ops = (
                 torch.ops.cadence.quantized_linear,
                 torch.ops.cadence.quantized_fully_connected,
             )
 
         for linear_op in linear_ops:
-            output = linear_op(
-                src,
-                weight,
-                bias,
-                in_zero_point,
-                weight_zero_point,
-                out_multiplier,
-                out_shift,
-                out_zero_point,
-                typing.cast(torch.Tensor, None),
+            # Get the function name of linear_op for debugging
+            op_name = (
+                linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op)
             )
+            if matmul:
+                assert "quantized_matmul" in op_name
+                output = linear_op(
+                    src,
+                    in_zero_point,
+                    weight,
+                    weight_zero_point,
+                    None,
+                    out_multiplier,
+                    out_shift,
+                    out_zero_point,
+                    transposed_matmul,
+                )
+            else:
+                assert (
+                    "quantized_linear" in op_name
+                    or "quantized_fully_connected" in op_name
+                )
+                bias = torch.arange(weight_shape[0]).to(torch.int32)
+                output = linear_op(
+                    src,
+                    weight,
+                    bias,
+                    in_zero_point,
+                    weight_zero_point,
+                    out_multiplier,
+                    out_shift,
+                    out_zero_point,
+                    typing.cast(torch.Tensor, None),
+                )
 
         self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
 
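For reference, the `out_multiplier`/`out_shift` comments in the new matmul test case encode a common fixed-point requantization scheme: the multiplier is a Q31 fraction (268435456 / 2^31 = 0.125) and the shift scales it by a power of two (shift = 1 doubles it to 0.25). A minimal sketch of that arithmetic, assuming round-to-nearest and exactly this multiplier/shift convention; the Cadence kernels' actual rounding and saturation behavior is not shown in this diff:

```python
def requantize(acc: int, out_multiplier: int, out_shift: int, out_zero_point: int) -> int:
    """Map an int32 accumulator to the quantized output domain.

    Sketch only: assumes effective_scale = (out_multiplier / 2**31) * 2**out_shift,
    matching the comments in the test case above. Saturation to the output
    dtype's range is omitted.
    """
    effective_scale = (out_multiplier / 2**31) * (2**out_shift)
    return round(acc * effective_scale) + out_zero_point

# With the test's values: multiplier 268435456 (= 0.125 * 2**31) and shift 1
# give an effective scale of 0.25, with out_zero_point 1 added afterwards.
print(requantize(100, 268435456, 1, 1))  # round(100 * 0.25) + 1 = 26
```

Note also that, per the diff, `quantized_matmul` takes its arguments in a different order than `quantized_linear` (`src, in_zero_point, weight, weight_zero_point, bias, ...` plus a trailing `transposed` flag), which is why the test now branches on `matmul` before invoking the op.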