@@ -81,6 +81,22 @@ def _get_dataset(*args, **kwargs):
     return res
 
 
+@contextmanager
+def _patch_move_embed(awq_model):
+    _origin_move_embed = awq_model.move_embed
+
+    def _move_embed(model, device: str):
+        if hasattr(model, '_hf_hook') and device != 'cpu':
+            return
+        _origin_move_embed(model, device)
+
+    awq_model.move_embed = _move_embed
+    try:
+        yield
+    finally:
+        awq_model.move_embed = _origin_move_embed
+
+
 def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
 
     from awq.quantize import quantizer
@@ -93,7 +109,8 @@ def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
     group_size = 128
     quant_config = {'zero_point': True, 'q_group_size': group_size, 'w_bit': _args.quant_bits, 'version': 'GEMM'}
     logger.info('Start quantizing the model...')
-    awq_model.quantize(tokenizer, quant_config=quant_config, n_parallel_calib_samples=batch_size)
+    with _patch_move_embed(awq_model):
+        awq_model.quantize(tokenizer, quant_config=quant_config, n_parallel_calib_samples=batch_size)
     quantizer.get_calib_dataset = _origin_get_calib_dataset  # recover
     awq_model.model.config.quantization_config = AwqConfig(
         bits=_args.quant_bits, group_size=group_size, zero_point=True, version='GEMM')
@@ -260,6 +277,7 @@ def llm_export(args: ExportArguments) -> None:
         from awq import AutoAWQForCausalLM
         model, template = prepare_model_template(
             args, device_map=args.quant_device_map, task='export', automodel_class=AutoAWQForCausalLM)
+        template.model = model.model
         awq_model_quantize(model, template.tokenizer, args.quant_batch_size)
         model.save_quantized(args.quant_output_dir)
     elif args.quant_method == 'gptq':
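
For readers unfamiliar with the temporary monkey-patching pattern introduced here, the following is a minimal, self-contained sketch of what the new `_patch_move_embed` context manager does: it swaps `move_embed` on the model wrapper for the duration of the `with` block and restores the original afterwards, skipping the embedding move when the inner model already carries an accelerate `_hf_hook` (device_map-based placement) and the target device is not `'cpu'`. `FakeAwqModel` and its `move_embed` method are hypothetical stand-ins for illustration, not the real AutoAWQ API.

```python
from contextlib import contextmanager


class FakeAwqModel:
    """Hypothetical stand-in for an AutoAWQ model wrapper (not the real library class)."""

    def move_embed(self, model, device: str):
        print(f'moving embeddings to {device}')


@contextmanager
def _patch_move_embed(awq_model):
    # Keep a reference to the original (bound) method so it can be restored.
    _origin_move_embed = awq_model.move_embed

    def _move_embed(model, device: str):
        # Skip the move when accelerate hooks already manage device placement
        # and the target is not 'cpu'.
        if hasattr(model, '_hf_hook') and device != 'cpu':
            return
        _origin_move_embed(model, device)

    awq_model.move_embed = _move_embed
    try:
        yield
    finally:
        # Always restore the original method, even if the body raises.
        awq_model.move_embed = _origin_move_embed


awq_model = FakeAwqModel()
inner_model = object()  # no _hf_hook attribute, so the move still happens
with _patch_move_embed(awq_model):
    awq_model.move_embed(inner_model, 'cuda')
# Outside the block the original bound method is back in place.
assert awq_model.move_embed.__func__ is FakeAwqModel.move_embed
```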