@@ -191,73 +191,31 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-            assert self.args.preq_mode in [
-                "8da4w",
-                "8da4w_output_8da8w",
-            ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-            assert hasattr(
-                self.args, "preq_group_size"
-            ), "preq_group_size must be specified"
-            assert hasattr(
-                self.args, "dtype_override"
-            ), "dtype_override must be specified"
+            self._transform_for_pre_quantization(checkpoint)
+
             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
-                transform_linear_for_pre_quantization,
-            )
-
-            mapping = {
-                "fp32": torch.float32,
-                "fp16": torch.float16,
-                "bf16": torch.bfloat16,
-            }
-
-            # Transform the output layer first if needed.
-            if self.args.preq_mode == "8da4w_output_8da8w":
-                from .source_transformation.pre_quantization import (
-                    transform_output_linear_for_pre_quantization,
-                )
-
-                self.model_ = transform_output_linear_for_pre_quantization(
-                    module=self.model_,
-                    checkpoint=checkpoint,
-                    dtype=mapping[self.args.dtype_override],
-                )
-
-            self.model_ = transform_linear_for_pre_quantization(
-                self.model_,
-                checkpoint,
-                self.args.preq_group_size,
-                mapping[self.args.dtype_override],
             )

-            embedding_bit_width, embedding_group_size = None, None
-            if hasattr(self.args, "preq_embedding_quantize"):
-                embedding_bit_width, embedding_group_size = (
-                    self.args.preq_embedding_quantize.split(",")
-                )
-                from .source_transformation.pre_quantization import (
-                    transform_embedding_for_pre_quantization,
+            sanitize_checkpoint_from_pre_quantization(checkpoint)
+        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+            print("Using QAT quantization.")
+            self._transform_for_pre_quantization(checkpoint)
+            if hasattr(self.args, "use_lora") and self.args.use_lora:
+                from .source_transformation.lora import (
+                    transform_linear_for_lora_after_quantization,
                 )

-                if (
-                    embedding_group_size == "none"
-                    or embedding_group_size == "None"
-                    or embedding_group_size == "0"
-                ):
-                    embedding_group_size = None
-                else:
-                    embedding_group_size = int(embedding_group_size)
-
-                self.model_ = transform_embedding_for_pre_quantization(
+                self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    mapping[self.args.dtype_override],
-                    int(embedding_bit_width),
-                    embedding_group_size,
+                    self.args.use_lora,
                 )

+            from .source_transformation.pre_quantization import (
+                sanitize_checkpoint_from_pre_quantization,
+            )
+
             sanitize_checkpoint_from_pre_quantization(checkpoint)

         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
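For orientation, here is a minimal sketch of which export flags reach each branch above, assuming an argparse-style namespace. The SimpleNamespace objects and the concrete values are illustrative only; the flag names come from the diff itself.

from types import SimpleNamespace

# SpinQuant path: a checkpoint that was pre-quantized with SpinQuant.
spinquant_args = SimpleNamespace(
    use_spin_quant=True,
    preq_mode="8da4w",      # or "8da4w_output_8da8w"
    preq_group_size=32,     # example value; must match the pre-quantized checkpoint
    dtype_override="fp32",  # keys of the dtype mapping: "fp32" / "fp16" / "bf16"
)

# QAT path, optionally folding LoRA adapters in after the linear transforms.
qat_lora_args = SimpleNamespace(
    use_qat=True,
    use_lora=16,            # passed straight through to transform_linear_for_lora_after_quantization
    preq_mode="8da4w",
    preq_group_size=32,
    dtype_override="fp32",
)

# Both branches call self._transform_for_pre_quantization(checkpoint) and then
# sanitize_checkpoint_from_pre_quantization(checkpoint); only the QAT branch
# additionally applies the LoRA transform when use_lora is set.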
@@ -318,3 +276,68 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.
         )
+
+    def _transform_for_pre_quantization(self, checkpoint):
+        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
+        assert self.args.preq_mode in [
+            "8da4w",
+            "8da4w_output_8da8w",
+        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
+        assert hasattr(
+            self.args, "preq_group_size"
+        ), "preq_group_size must be specified"
+        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        from .source_transformation.pre_quantization import (
+            transform_linear_for_pre_quantization,
+        )
+
+        mapping = {
+            "fp32": torch.float32,
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+        }
+
+        # Transform the output layer first if needed.
+        if self.args.preq_mode == "8da4w_output_8da8w":
+            from .source_transformation.pre_quantization import (
+                transform_output_linear_for_pre_quantization,
+            )
+
+            self.model_ = transform_output_linear_for_pre_quantization(
+                module=self.model_,
+                checkpoint=checkpoint,
+                dtype=mapping[self.args.dtype_override],
+            )
+
+        self.model_ = transform_linear_for_pre_quantization(
+            self.model_,
+            checkpoint,
+            self.args.preq_group_size,
+            mapping[self.args.dtype_override],
+        )
+
+        embedding_bit_width, embedding_group_size = None, None
+        if hasattr(self.args, "preq_embedding_quantize"):
+            embedding_bit_width, embedding_group_size = (
+                self.args.preq_embedding_quantize.split(",")
+            )
+            from .source_transformation.pre_quantization import (
+                transform_embedding_for_pre_quantization,
+            )
+
+            if (
+                embedding_group_size == "none"
+                or embedding_group_size == "None"
+                or embedding_group_size == "0"
+            ):
+                embedding_group_size = None
+            else:
+                embedding_group_size = int(embedding_group_size)
+
+            self.model_ = transform_embedding_for_pre_quantization(
+                self.model_,
+                checkpoint,
+                mapping[self.args.dtype_override],
+                int(embedding_bit_width),
+                embedding_group_size,
+            )
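For reference, a standalone sketch of the preq_embedding_quantize parsing performed by the new helper. The "<bit_width>,<group_size>" format is inferred from the split above; the function name here is hypothetical, not part of the module.

from typing import Optional, Tuple


def parse_preq_embedding_quantize(value: str) -> Tuple[int, Optional[int]]:
    # Split "<bit_width>,<group_size>" and normalize the group size the same
    # way _transform_for_pre_quantization does: "none" / "None" / "0" -> None.
    bit_width, group_size = value.split(",")
    if group_size in ("none", "None", "0"):
        return int(bit_width), None
    return int(bit_width), int(group_size)


# Example: 4-bit embedding weights with a quantization group size of 32.
assert parse_preq_embedding_quantize("4,32") == (4, 32)
assert parse_preq_embedding_quantize("8,none") == (8, None)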