@@ -354,46 +354,124 @@ def linear_q8ta_q8csw(
354
354
lib .impl (name , linear_q8ta_q8csw , "CompositeExplicitAutograd" )
355
355
qa_q8csw_linear = getattr (getattr (torch .ops , namespace ), name )
356
356
357
- #######################
358
- ## conv2d_q8ta_q8csw ##
359
- #######################
357
+ ############################
358
+ ## conv2d_q8ta_q8csw_q8to ##
359
+ ############################
360
360
361
361
362
def conv2d_q8ta_q8csw_q8to(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    bias: Optional[torch.Tensor],
    kernel_size: list,
    stride: list,
    padding: list,
    dilation: list,
    groups: int,
):
    """Reference (composite) implementation of a statically quantized conv2d.

    Dequantizes the int8 activation (per-tensor affine) and the int8 weights
    (per-output-channel, symmetric), runs a float conv2d, then re-quantizes
    the result to int8 with the output scale / zero point.

    Args:
        x: int8 activation tensor; channel dim is assumed to be dim 1 (NCHW)
            — TODO confirm against callers.
        input_scale, input_zero_point: per-tensor activation qparams.
        weights: packed int8 weights of shape (OC, K) where K may be padded
            past K_h * K_w * IC_per_group for alignment.
        weight_sums: precomputed per-channel weight sums; unused on this
            reference path (presumably consumed by an optimized kernel).
        weight_scales: per-output-channel scales; may be padded past OC.
        output_scale, output_zero_point: per-tensor output qparams.
        bias: optional float bias; may be padded past OC.
        kernel_size, stride, padding, dilation, groups: conv2d parameters.

    Returns:
        int8 output tensor of the convolution.
    """
    # Dequantize the activation back to float.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, -128, 127, x.dtype
    )

    # Calculate weight dimensions
    OC = weights.shape[0]
    assert OC % groups == 0, "Output channels must be divisible by groups"
    # Integer floor division instead of int(a / b): the latter routes
    # through float and can misbehave with symbolic (SymInt) shapes.
    IC_per_group = x.shape[1] // groups
    K_h, K_w = kernel_size[0], kernel_size[1]

    orig_weight_K_dim = K_h * K_w * IC_per_group
    # Remove any padding added to in_features dim to align to a multiple of 4
    if weights.shape[-1] > orig_weight_K_dim:
        weights = weights[:, :orig_weight_K_dim]

    # Remove any padding added to output channels dim to align to a multiple of 4
    if weight_scales.shape[0] > OC:
        weight_scales = weight_scales[:OC]

    if bias is not None:
        bias = bias[:OC]

    # Reshape to original 4D format (OC, IC, H, W)
    weights = weights.view(OC, IC_per_group, K_h, K_w)

    # Symmetric weight quantization: all zero points are zero.
    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
    # Dequantize weights
    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
        weights,
        weight_scales,
        weight_zeros,
        0,  # axis=0 for output channel quantization
        -127,
        127,
        torch.int8,
    )

    # Perform convolution
    out = torch.nn.functional.conv2d(
        x, weights, bias, stride, padding, dilation, groups
    )

    # Re-quantize the float result to int8 with the output qparams.
    out = torch.ops.quantized_decomposed.quantize_per_tensor(
        out, output_scale, output_zero_point, -128, 127, torch.int8
    )

    return out
396
424
425
+
426
+ name = "conv2d_q8ta_q8csw_q8to"
427
+ lib .define (
428
+ f"""
429
+ { name } (
430
+ Tensor x,
431
+ float input_scale,
432
+ int input_zero_point,
433
+ Tensor weights,
434
+ Tensor weight_sums,
435
+ Tensor weight_scales,
436
+ float output_scale,
437
+ int output_zero_point,
438
+ Tensor? bias,
439
+ SymInt[] kernel_size,
440
+ SymInt[] stride,
441
+ SymInt[] padding,
442
+ SymInt[] dilation,
443
+ SymInt groups) -> Tensor
444
+ """
445
+ )
446
+ lib .impl (name , conv2d_q8ta_q8csw_q8to , "CompositeExplicitAutograd" )
447
+ conv2d_q8ta_q8csw_op = getattr (getattr (torch .ops , namespace ), name )
448
+
449
+
450
+ def conv2d_q8ta_q8csw_q8to_dw (
451
+ x : torch .Tensor ,
452
+ input_scale : float ,
453
+ input_zero_point : int ,
454
+ weights : torch .Tensor ,
455
+ weight_sums : torch .Tensor ,
456
+ weight_scales : torch .Tensor ,
457
+ output_scale : float ,
458
+ output_zero_point : int ,
459
+ bias : Optional [torch .Tensor ],
460
+ kernel_size : list ,
461
+ stride : list ,
462
+ padding : list ,
463
+ dilation : list ,
464
+ groups : int ,
465
+ ):
466
+ x = torch .ops .quantized_decomposed .dequantize_per_tensor (
467
+ x , input_scale , input_zero_point , - 128 , 127 , x .dtype
468
+ )
469
+
470
+ # Restore weight to original data layout
471
+ K_h , K_w , OC = weights .shape
472
+ weights = weights .permute (2 , 0 , 1 ).reshape (OC , 1 , K_h , K_w )
473
+
474
+ weight_zeros = torch .zeros_like (weight_scales , dtype = torch .int32 )
397
475
# Dequantize weights
398
476
weights = torch .ops .quantized_decomposed .dequantize_per_channel (
399
477
weights ,
@@ -410,10 +488,14 @@ def conv2d_q8ta_q8csw(
410
488
x , weights , bias , stride , padding , dilation , groups
411
489
)
412
490
491
+ out = torch .ops .quantized_decomposed .quantize_per_tensor (
492
+ out , output_scale , output_zero_point , - 128 , 127 , torch .int8
493
+ )
494
+
413
495
return out
414
496
415
497
416
- name = "conv2d_q8ta_q8csw "
498
+ name = "conv2d_q8ta_q8csw_q8to_dw "
417
499
lib .define (
418
500
f"""
419
501
{ name } (
@@ -423,6 +505,8 @@ def conv2d_q8ta_q8csw(
423
505
Tensor weights,
424
506
Tensor weight_sums,
425
507
Tensor weight_scales,
508
+ float output_scale,
509
+ int output_zero_point,
426
510
Tensor? bias,
427
511
SymInt[] kernel_size,
428
512
SymInt[] stride,
@@ -431,8 +515,8 @@ def conv2d_q8ta_q8csw(
431
515
SymInt groups) -> Tensor
432
516
"""
433
517
)
434
- lib .impl (name , conv2d_q8ta_q8csw , "CompositeExplicitAutograd" )
435
- conv2d_q8ta_q8csw_op = getattr (getattr (torch .ops , namespace ), name )
518
+ lib .impl (name , conv2d_q8ta_q8csw_q8to_dw , "CompositeExplicitAutograd" )
519
+ conv2d_q8ta_q8csw_dw_op = getattr (getattr (torch .ops , namespace ), name )
436
520
437
521
######################
438
522
## apply_rotary_emb ##
@@ -452,3 +536,39 @@ def apply_rotary_emb_impl(
452
536
)
453
537
lib .impl (name , apply_rotary_emb_impl , "CompositeExplicitAutograd" )
454
538
apply_rotary_emb_op = getattr (getattr (torch .ops , namespace ), name )
539
+
540
+ #############################
541
+ ## quantize/dequantize ops ##
542
+ #############################
543
+
544
+
545
def quantize_q8ta_for_conv2d_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
):
    """Quantize a float tensor to int8 with per-tensor affine qparams.

    Thin wrapper over quantized_decomposed.quantize_per_tensor using the
    fixed int8 range [-128, 127].
    """
    quantize = torch.ops.quantized_decomposed.quantize_per_tensor
    qmin, qmax = -128, 127
    return quantize(input, scale, zero_point, qmin, qmax, torch.int8)
553
+
554
+
555
+ name = "quantize_q8ta_for_conv2d"
556
+ lib .define (f"{ name } (Tensor input, float scale, int zero_point) -> Tensor" )
557
+ lib .impl (name , quantize_q8ta_for_conv2d_impl , "CompositeExplicitAutograd" )
558
+ quantize_q8ta_for_conv2d_op = getattr (getattr (torch .ops , namespace ), name )
559
+
560
+
561
def dequantize_q8to_from_conv2d_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
):
    """Dequantize an int8 tensor back to float with per-tensor qparams.

    Thin wrapper over quantized_decomposed.dequantize_per_tensor using the
    fixed int8 range [-128, 127]; the dtype argument describes the input.
    """
    dequantize = torch.ops.quantized_decomposed.dequantize_per_tensor
    qmin, qmax = -128, 127
    return dequantize(input, scale, zero_point, qmin, qmax, input.dtype)
569
+
570
+
571
+ name = "dequantize_q8to_from_conv2d"
572
+ lib .define (f"{ name } (Tensor input, float scale, int zero_point) -> Tensor" )
573
+ lib .impl (name , dequantize_q8to_from_conv2d_impl , "CompositeExplicitAutograd" )
574
+ dequantize_q8to_from_conv2d_op = getattr (getattr (torch .ops , namespace ), name )
0 commit comments