@@ -693,53 +693,56 @@ def _get_kernel_with_merged_lora(self):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode is not None:
-            kernel_value = self._kernel
-            kernel_scale = self.kernel_scale
-            if self.lora_enabled:
-                # Dequantize kernel to float
-                if self.quantization_mode == "int4":
-                    unpacked_kernel = quantizers.unpack_int4(
-                        kernel_value, self._orig_input_dim
-                    )
-                    float_kernel = ops.divide(
-                        ops.cast(unpacked_kernel, self.compute_dtype),
-                        kernel_scale,
-                    )
-                    quant_range = (-8, 7)
-                elif self.quantization_mode == "int8":
-                    float_kernel = ops.divide(
-                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
-                    )
-                    quant_range = (-127, 127)
-                else:
-                    raise ValueError(
-                        "Unsupported quantization mode: "
-                        f"{self.quantization_mode}"
-                    )
-
-                # Merge LoRA weights in float domain
-                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
-                    self.lora_kernel_a, self.lora_kernel_b
-                )
-                merged_float_kernel = ops.add(float_kernel, lora_delta)
-
-                # Requantize
-                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
-                    merged_float_kernel,
-                    axis=0,
-                    value_range=quant_range,
-                    dtype="int8",
-                    to_numpy=True,
-                )
-                kernel_scale = ops.squeeze(kernel_scale, axis=0)
-
-                # Pack if int4
-                if self.quantization_mode == "int4":
-                    kernel_value, _, _ = quantizers.pack_int4(
-                        requantized_kernel
-                    )
-                else:
-                    kernel_value = requantized_kernel
+        if self.dtype_policy.quantization_mode is None:
+            return self.kernel, None
+
+        kernel_value = self._kernel
+        kernel_scale = self.kernel_scale
+
+        if not self.lora_enabled:
             return kernel_value, kernel_scale
-        return self.kernel, None
+
+        # Dequantize, Merge, and Re-quantize
+
+        # Dequantize kernel to float
+        if self.quantization_mode == "int4":
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel_value, self._orig_input_dim
+            )
+            float_kernel = ops.divide(
+                ops.cast(unpacked_kernel, self.compute_dtype),
+                kernel_scale,
+            )
+            quant_range = (-8, 7)
+        elif self.quantization_mode == "int8":
+            float_kernel = ops.divide(
+                ops.cast(kernel_value, self.compute_dtype), kernel_scale
+            )
+            quant_range = (-127, 127)
+        else:
+            raise ValueError(
+                f"Unsupported quantization mode: {self.quantization_mode}"
+            )
+
+        # Merge LoRA weights in float domain
+        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+            self.lora_kernel_a, self.lora_kernel_b
+        )
+        merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+        # Requantize
+        requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+            merged_float_kernel,
+            axis=0,
+            value_range=quant_range,
+            dtype="int8",
+            to_numpy=True,
+        )
+        kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+        # Pack if int4
+        if self.quantization_mode == "int4":
+            kernel_value, _, _ = quantizers.pack_int4(requantized_kernel)
+        else:
+            kernel_value = requantized_kernel
+        return kernel_value, kernel_scale
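
For reference, here is a minimal standalone NumPy sketch of the dequantize → merge → requantize round trip the new code path performs for the int8 case. The `abs_max_quantize` helper below is a simplified stand-in for `keras.quantizers.abs_max_quantize` (not the actual implementation), and the shapes, seed, and LoRA hyperparameters are made up for illustration:

```python
import numpy as np

def abs_max_quantize(x, axis=0, value_range=(-127, 127)):
    # Simplified stand-in: per-column scale chosen so the largest
    # |value| along `axis` maps to the edge of the integer range.
    abs_max = np.maximum(np.max(np.abs(x), axis=axis, keepdims=True), 1e-9)
    scale = value_range[1] / abs_max
    quantized = np.clip(np.round(x * scale), *value_range).astype(np.int8)
    return quantized, scale

rng = np.random.default_rng(0)
input_dim, output_dim, rank, alpha = 8, 4, 2, 2.0  # made-up shapes

# Stand-ins for the stored quantized kernel, its scale, and the LoRA factors.
kernel_value, kernel_scale = abs_max_quantize(
    rng.standard_normal((input_dim, output_dim)).astype("float32")
)
lora_kernel_a = rng.standard_normal((input_dim, rank)).astype("float32")
lora_kernel_b = rng.standard_normal((rank, output_dim)).astype("float32")

# 1. Dequantize kernel to float (quantized value / scale).
float_kernel = kernel_value.astype("float32") / kernel_scale
# 2. Merge the LoRA delta in the float domain: (alpha / rank) * A @ B.
merged = float_kernel + (alpha / rank) * (lora_kernel_a @ lora_kernel_b)
# 3. Requantize the merged kernel with a freshly computed scale.
requantized_kernel, new_scale = abs_max_quantize(merged)

# Dequantizing again stays close to the float merge (only rounding error).
assert np.max(np.abs(requantized_kernel / new_scale - merged)) < 1e-1
```

Merging in the float domain and requantizing with a fresh scale matters because the LoRA delta can change the per-column absolute maximum, so reusing the old `kernel_scale` would clip or waste the int8 range.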
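The int4 branch additionally packs and unpacks the kernel around the merge. The sketch below is a hypothetical illustration of the underlying idea (two signed 4-bit values sharing one int8 byte), not Keras' actual `pack_int4`/`unpack_int4`, which operate along a kernel axis and also return packing metadata:

```python
import numpy as np

def pack_int4_pairs(values):
    # `values` holds int8 numbers already clipped to the int4 range [-8, 7].
    assert values.size % 2 == 0
    u = values.astype(np.uint8)          # two's-complement bit patterns
    low = u[0::2] & 0x0F                 # even elements -> low nibble
    high = (u[1::2] & 0x0F) << 4         # odd elements -> high nibble
    return (low | high).astype(np.int8)  # half the storage of the input

def unpack_int4_pairs(packed):
    u = packed.astype(np.uint8)
    low = (u & 0x0F).astype(np.int8)
    high = (u >> 4).astype(np.int8)
    # Sign-extend: nibble values >= 8 represent negative int4 numbers.
    low = np.where(low >= 8, low - 16, low)
    high = np.where(high >= 8, high - 16, high)
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2], out[1::2] = low, high
    return out

vals = np.array([-8, -1, 0, 7], dtype=np.int8)
assert np.array_equal(unpack_int4_pairs(pack_int4_pairs(vals)), vals)
```

This also shows why the diff requantizes to `dtype="int8"` with `value_range=(-8, 7)` in the int4 case: the values are produced as int8 first and only then packed two-per-byte.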