@@ -165,7 +165,7 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)

             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
@@ -174,8 +174,9 @@ def __init__(self, **kwargs):
             sanitize_checkpoint_from_pre_quantization(checkpoint)
         elif hasattr(self.args, "use_qat") and self.args.use_qat:
             print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
             if hasattr(self.args, "use_lora") and self.args.use_lora:
+                assert model_args.lora_args["rank"] == self.args.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -251,7 +252,7 @@ def get_example_inputs_kvcache_sdpa(self):
             ),  # start_pos, what token of output are we on.
         )

-    def _transform_for_pre_quantization(self, checkpoint):
+    def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
         assert self.args.preq_mode in [
             "8da4w",
@@ -265,6 +266,8 @@ def _transform_for_pre_quantization(self, checkpoint):
             transform_linear_for_pre_quantization,
         )

+        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+
         mapping = {
             "fp32": torch.float32,
             "fp16": torch.float16,
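The diff threads `model_args` (the metadata read from the pre-quantized checkpoint) into `_transform_for_pre_quantization` so the CLI flags can be validated against it: `preq_group_size` must equal the checkpoint's `quantization_args["group_size"]`, and when LoRA is enabled, `--use_lora` must equal the checkpoint's `lora_args["rank"]`. Below is a minimal standalone sketch of those checks, not the actual ExecuTorch class; the helper name `check_pre_quantization_args` and the `SimpleNamespace` argument containers are hypothetical stand-ins for illustration.

```python
# Illustrative sketch only: mirrors the consistency checks the diff adds,
# with hypothetical argument containers instead of the real ExecuTorch args.
from types import SimpleNamespace


def check_pre_quantization_args(args, model_args):
    # The group size passed on the command line must match the group size
    # the pre-quantized checkpoint was produced with.
    assert args.preq_group_size == model_args.quantization_args["group_size"], (
        f"preq_group_size {args.preq_group_size} does not match checkpoint "
        f"group_size {model_args.quantization_args['group_size']}"
    )
    # If LoRA is enabled, the requested rank must match the checkpoint's rank.
    if getattr(args, "use_lora", 0):
        assert model_args.lora_args["rank"] == args.use_lora, (
            f"use_lora rank {args.use_lora} does not match checkpoint "
            f"LoRA rank {model_args.lora_args['rank']}"
        )


# Flags consistent with the checkpoint metadata: no assertion fires.
args = SimpleNamespace(preq_group_size=32, use_lora=16)
model_args = SimpleNamespace(
    quantization_args={"group_size": 32},
    lora_args={"rank": 16},
)
check_pre_quantization_args(args, model_args)
```

Failing fast on a mismatch here surfaces a wrong flag at export time instead of producing a silently mis-quantized model.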