@@ -165,7 +165,7 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)

             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
@@ -174,8 +174,9 @@ def __init__(self, **kwargs):
             sanitize_checkpoint_from_pre_quantization(checkpoint)
         elif hasattr(self.args, "use_qat") and self.args.use_qat:
             print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
             if hasattr(self.args, "use_lora") and self.args.use_lora:
+                assert model_args.lora_args["rank"] == self.args.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -251,7 +252,7 @@ def get_example_inputs_kvcache_sdpa(self):
                 ),  # start_pos, what token of output are we on.
             )

-    def _transform_for_pre_quantization(self, checkpoint):
+    def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
         assert self.args.preq_mode in [
             "8da4w",
@@ -265,6 +266,8 @@ def _transform_for_pre_quantization(self, checkpoint):
             transform_linear_for_pre_quantization,
         )

+        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+
         mapping = {
             "fp32": torch.float32,
             "fp16": torch.float16,