@@ -354,46 +354,124 @@ def linear_q8ta_q8csw(
354354lib .impl (name , linear_q8ta_q8csw , "CompositeExplicitAutograd" )
355355qa_q8csw_linear = getattr (getattr (torch .ops , namespace ), name )
356356
357- #######################
358- ## conv2d_q8ta_q8csw ##
359- #######################
357+ ############################
358+ ## conv2d_q8ta_q8csw_q8to ##
359+ ############################
360360
361361
362- def conv2d_q8ta_q8csw (
def conv2d_q8ta_q8csw_q8to(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    bias: Optional[torch.Tensor],
    kernel_size: list,
    stride: list,
    padding: list,
    dilation: list,
    groups: int,
):
    """Reference implementation of conv2d with per-tensor int8 activations
    (q8ta), per-channel symmetric int8 weights (q8csw), and per-tensor int8
    output (q8to).

    The computation is decomposed as: dequantize input -> dequantize weights
    -> float conv2d -> quantize output.

    Args:
        x: int8 per-tensor-quantized input activation (NCHW).
        input_scale / input_zero_point: quantization params of ``x``.
        weights: int8 weights in packed 2D layout (OC, K_h * K_w * IC_per_group);
            both dims may carry padding to a multiple of 4.
        weight_sums: precomputed per-channel weight sums (unused by this
            reference path; consumed by the optimized backend kernel).
        weight_scales: per-output-channel weight scales (may be padded).
        output_scale / output_zero_point: quantization params of the output.
        bias: optional float bias (may be padded to match padded OC).
        kernel_size, stride, padding, dilation, groups: standard conv2d params.

    Returns:
        int8 per-tensor-quantized output activation.
    """
    # Dequantize the per-tensor quantized int8 activation back to float.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, -128, 127, x.dtype
    )

    # Calculate weight dimensions
    OC = weights.shape[0]
    assert OC % groups == 0, "Output channels must be divisible by groups"
    # Integer floor division instead of int(a / b): exact, and safe under
    # symbolic/traced shapes where a float round-trip is lossy.
    IC_per_group = x.shape[1] // groups
    K_h, K_w = kernel_size[0], kernel_size[1]

    orig_weight_K_dim = K_h * K_w * IC_per_group
    # Remove any padding added to in_features dim to align to a multiple of 4
    if weights.shape[-1] > orig_weight_K_dim:
        weights = weights[:, :orig_weight_K_dim]

    # Remove any padding added to output channels dim to align to a multiple of 4
    if weight_scales.shape[0] > OC:
        weight_scales = weight_scales[:OC]
    if bias is not None:
        bias = bias[:OC]

    # Reshape to original 4D format (OC, IC_per_group, K_h, K_w)
    weights = weights.view(OC, IC_per_group, K_h, K_w)

    # Symmetric weight quantization: zero points are all zero.
    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
    # Dequantize weights
    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
        weights,
        weight_scales,
        weight_zeros,
        0,  # axis=0 for output channel quantization
        -127,
        127,
        torch.int8,
    )

    # Perform convolution in float
    out = torch.nn.functional.conv2d(
        x, weights, bias, stride, padding, dilation, groups
    )

    # Re-quantize the result to per-tensor int8 output.
    out = torch.ops.quantized_decomposed.quantize_per_tensor(
        out, output_scale, output_zero_point, -128, 127, torch.int8
    )

    return out
396424
425+
426+ name = "conv2d_q8ta_q8csw_q8to"
427+ lib .define (
428+ f"""
429+ { name } (
430+ Tensor x,
431+ float input_scale,
432+ int input_zero_point,
433+ Tensor weights,
434+ Tensor weight_sums,
435+ Tensor weight_scales,
436+ float output_scale,
437+ int output_zero_point,
438+ Tensor? bias,
439+ SymInt[] kernel_size,
440+ SymInt[] stride,
441+ SymInt[] padding,
442+ SymInt[] dilation,
443+ SymInt groups) -> Tensor
444+ """
445+ )
446+ lib .impl (name , conv2d_q8ta_q8csw_q8to , "CompositeExplicitAutograd" )
447+ conv2d_q8ta_q8csw_op = getattr (getattr (torch .ops , namespace ), name )
448+
449+
450+ def conv2d_q8ta_q8csw_q8to_dw (
451+ x : torch .Tensor ,
452+ input_scale : float ,
453+ input_zero_point : int ,
454+ weights : torch .Tensor ,
455+ weight_sums : torch .Tensor ,
456+ weight_scales : torch .Tensor ,
457+ output_scale : float ,
458+ output_zero_point : int ,
459+ bias : Optional [torch .Tensor ],
460+ kernel_size : list ,
461+ stride : list ,
462+ padding : list ,
463+ dilation : list ,
464+ groups : int ,
465+ ):
466+ x = torch .ops .quantized_decomposed .dequantize_per_tensor (
467+ x , input_scale , input_zero_point , - 128 , 127 , x .dtype
468+ )
469+
470+ # Restore weight to original data layout
471+ K_h , K_w , OC = weights .shape
472+ weights = weights .permute (2 , 0 , 1 ).reshape (OC , 1 , K_h , K_w )
473+
474+ weight_zeros = torch .zeros_like (weight_scales , dtype = torch .int32 )
397475 # Dequantize weights
398476 weights = torch .ops .quantized_decomposed .dequantize_per_channel (
399477 weights ,
@@ -410,10 +488,14 @@ def conv2d_q8ta_q8csw(
410488 x , weights , bias , stride , padding , dilation , groups
411489 )
412490
491+ out = torch .ops .quantized_decomposed .quantize_per_tensor (
492+ out , output_scale , output_zero_point , - 128 , 127 , torch .int8
493+ )
494+
413495 return out
414496
415497
416- name = "conv2d_q8ta_q8csw "
498+ name = "conv2d_q8ta_q8csw_q8to_dw "
417499lib .define (
418500 f"""
419501 { name } (
@@ -423,6 +505,8 @@ def conv2d_q8ta_q8csw(
423505 Tensor weights,
424506 Tensor weight_sums,
425507 Tensor weight_scales,
508+ float output_scale,
509+ int output_zero_point,
426510 Tensor? bias,
427511 SymInt[] kernel_size,
428512 SymInt[] stride,
@@ -431,8 +515,8 @@ def conv2d_q8ta_q8csw(
431515 SymInt groups) -> Tensor
432516 """
433517)
434- lib .impl (name , conv2d_q8ta_q8csw , "CompositeExplicitAutograd" )
435- conv2d_q8ta_q8csw_op = getattr (getattr (torch .ops , namespace ), name )
518+ lib .impl (name , conv2d_q8ta_q8csw_q8to_dw , "CompositeExplicitAutograd" )
519+ conv2d_q8ta_q8csw_dw_op = getattr (getattr (torch .ops , namespace ), name )
436520
437521######################
438522## apply_rotary_emb ##
@@ -452,3 +536,39 @@ def apply_rotary_emb_impl(
452536)
453537lib .impl (name , apply_rotary_emb_impl , "CompositeExplicitAutograd" )
454538apply_rotary_emb_op = getattr (getattr (torch .ops , namespace ), name )
539+
540+ #############################
541+ ## quantize/dequantize ops ##
542+ #############################
543+
544+
def quantize_q8ta_for_conv2d_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
):
    """Quantize a float tensor to per-tensor int8 (range [-128, 127]).

    NOTE: the parameter is named ``input`` to match the registered op schema.
    """
    qmin, qmax = -128, 127
    quantize = torch.ops.quantized_decomposed.quantize_per_tensor
    return quantize(input, scale, zero_point, qmin, qmax, torch.int8)
553+
554+
555+ name = "quantize_q8ta_for_conv2d"
556+ lib .define (f"{ name } (Tensor input, float scale, int zero_point) -> Tensor" )
557+ lib .impl (name , quantize_q8ta_for_conv2d_impl , "CompositeExplicitAutograd" )
558+ quantize_q8ta_for_conv2d_op = getattr (getattr (torch .ops , namespace ), name )
559+
560+
def dequantize_q8to_from_conv2d_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
):
    """Dequantize a per-tensor int8 tensor back to float.

    The final ``input.dtype`` argument tells the decomposed op the dtype of
    the quantized input (the result is floating point).
    """
    qmin, qmax = -128, 127
    dequantize = torch.ops.quantized_decomposed.dequantize_per_tensor
    return dequantize(input, scale, zero_point, qmin, qmax, input.dtype)
569+
570+
571+ name = "dequantize_q8to_from_conv2d"
572+ lib .define (f"{ name } (Tensor input, float scale, int zero_point) -> Tensor" )
573+ lib .impl (name , dequantize_q8to_from_conv2d_impl , "CompositeExplicitAutograd" )
574+ dequantize_q8to_from_conv2d_op = getattr (getattr (torch .ops , namespace ), name )
0 commit comments