@@ -177,6 +177,8 @@ def test_quantized_add(
                    0,  # out_zero_point
                    torch.tensor([[-2]], dtype=dtype),  # expected_output
                    per_tensor,
+                   False,
+                   False,
                )
                for (per_tensor, dtype) in (
                    (False, torch.int8),
@@ -200,6 +202,8 @@ def test_quantized_add(
                    0,  # out_zero_point
                    torch.tensor([[-10, -30]], dtype=dtype),  # expected_output
                    per_tensor,
+                   False,
+                   False,
                )
                for (per_tensor, dtype) in (
                    (False, torch.int8),
@@ -225,6 +229,8 @@ def test_quantized_add(
                        [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
                    ),  # expected_output
                    per_tensor,
+                   False,
+                   False,
                )
                for (per_tensor, dtype) in (
                    (False, torch.int8),
@@ -248,6 +254,8 @@ def test_quantized_add(
                    1,  # out_zero_point
                    torch.tensor([[-15, 25]], dtype=dtype),  # expected_output
                    per_tensor,
+                   False,
+                   False,
                )
                for (per_tensor, dtype) in (
                    (False, torch.int8),
@@ -271,6 +279,8 @@ def test_quantized_add(
                    1,  # out_zero_point
                    torch.tensor([[-23, 17]], dtype=dtype),  # expected_output
                    False,
+                   False,
+                   False,
                )
                for dtype in (torch.int8, torch.uint8)
            ],
@@ -292,9 +302,34 @@ def test_quantized_add(
                    1,  # out_zero_point
                    torch.tensor([[-7, 13]], dtype=dtype),  # expected_output
                    per_tensor,
+                   False,
+                   False,
                )
                for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
            ],
+           *[
+               (
+                   torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
+                   torch.Size(
+                       [2, 2]
+                   ),  # weight_shape: 2 output features, 2 input features
+                   2,  # in_zero_point
+                   torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                   torch.tensor(
+                       [268435456], dtype=torch.int32
+                   ),  # out_multiplier (0.125 * 2^31)
+                   torch.tensor(
+                       [1], dtype=torch.int64
+                   ),  # out_shift (shift=1, doubles the scale)
+                   1,  # out_zero_point
+                   torch.tensor([[-7, 17]], dtype=dtype),  # expected_output
+                   per_tensor,
+                   matmul,
+                   transposed_matmul,
+               )
+               for (matmul, transposed_matmul) in ((True, False), (True, True))
+               for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
+           ],
        ]
    )
    def test_quantized_linear(
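
A quick sanity check on the new parameterization above: the inline comments imply the usual Q31 fixed-point requantization convention, where the effective output scale is out_multiplier / 2^31, further scaled by 2^out_shift. A minimal sketch under that assumption (the convention is inferred from the comments, not stated in the diff):

    # Illustrative check of the parameter comments above (assumed Q31 convention).
    out_multiplier = 268435456                                # 0.125 * 2^31
    out_shift = 1                                             # "doubles the scale"
    effective_scale = out_multiplier / 2**31 * 2**out_shift
    assert effective_scale == 0.25
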
@@ -308,7 +343,12 @@ def test_quantized_linear(
        out_zero_point: int,
        expected_output: torch.Tensor,
        per_tensor: bool,
+       matmul: bool,
+       transposed_matmul: bool,
    ) -> None:
+       if not per_tensor and matmul:
+           self.skipTest("Only per_tensor supported for matmul")
+
        src = (
            torch.arange(np.prod(src_shape))
            .reshape(src_shape)
@@ -319,7 +359,9 @@ def test_quantized_linear(
            .reshape(weight_shape)
            .to(expected_output.dtype)
        )
-       bias = torch.arange(weight_shape[0]).to(torch.int32)
+       if matmul and not transposed_matmul:
+           weight = weight.T
+
        if per_tensor:
            weight_zero_point = weight_zero_point[0]
            out_multiplier = out_multiplier[0]
@@ -328,38 +370,75 @@ def test_quantized_linear(
        if per_tensor:
            match expected_output.dtype:
                case torch.int8:
-                   linear_ops = (
-                       torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
-                       torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
-                   )
+                   if matmul:
+                       linear_ops = (
+                           # Doesn't have a per-tensor name, but it is per tensor
+                           torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s,
+                       )
+                   else:
+                       linear_ops = (
+                           torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
+                           torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
+                       )
                case torch.uint8:
-                   linear_ops = (
-                       torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
-                       torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
-                   )
+                   if matmul:
+                       linear_ops = (
+                           torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u,
+                       )
+                   else:
+                       linear_ops = (
+                           torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
+                           torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
+                       )
                case _:
-                   linear_ops = (
-                       torch.ops.cadence.quantized_linear.per_tensor,
-                       torch.ops.cadence.quantized_fully_connected.per_tensor,
-                   )
+                   if matmul:
+                       linear_ops = (torch.ops.cadence.quantized_matmul,)
+                   else:
+                       linear_ops = (
+                           torch.ops.cadence.quantized_linear.per_tensor,
+                           torch.ops.cadence.quantized_fully_connected.per_tensor,
+                       )
        else:
            linear_ops = (
                torch.ops.cadence.quantized_linear,
                torch.ops.cadence.quantized_fully_connected,
            )

        for linear_op in linear_ops:
-           output = linear_op(
-               src,
-               weight,
-               bias,
-               in_zero_point,
-               weight_zero_point,
-               out_multiplier,
-               out_shift,
-               out_zero_point,
-               typing.cast(torch.Tensor, None),
+           # Get the function name for linear_op for debugging
+           op_name = (
+               linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op)
            )
+           if matmul:
+               assert "quantized_matmul" in op_name
+               output = linear_op(
+                   src,
+                   in_zero_point,
+                   weight,
+                   weight_zero_point,
+                   None,
+                   out_multiplier,
+                   out_shift,
+                   out_zero_point,
+                   transposed_matmul,
+               )
+           else:
+               assert (
+                   "quantized_linear" in op_name
+                   or "quantized_fully_connected" in op_name
+               )
+               bias = torch.arange(weight_shape[0]).to(torch.int32)
+               output = linear_op(
+                   src,
+                   weight,
+                   bias,
+                   in_zero_point,
+                   weight_zero_point,
+                   out_multiplier,
+                   out_shift,
+                   out_zero_point,
+                   typing.cast(torch.Tensor, None),
+               )

        self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
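
A note on the weight = weight.T branch: quantized_linear takes a (out_features, in_features) weight, so when transposed_matmul is False the test pre-transposes it into (in_features, out_features) for a plain matmul, while transposed_matmul=True hands the linear-style weight to the op unchanged. Either way the same product is computed, which is why both new parameterizations share one expected_output. A minimal float-tensor sketch of that equivalence (illustrative only; the internal-transpose behavior of transposed_matmul=True is inferred from the test, not documented in the diff):

    import torch

    src = torch.arange(2.0).reshape(1, 2)  # (batch, in_features)
    w = torch.arange(4.0).reshape(2, 2)    # linear-style weight: (out_features, in_features)

    w_pre = w.T                            # what the test does for transposed_matmul=False
    y_plain = src @ w_pre                  # op then computes a plain src @ weight
    y_transposed = src @ w.T               # transposed_matmul=True: op transposes internally
    assert torch.equal(y_plain, y_transposed)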