@@ -109,7 +109,7 @@ def merge_lora(args: InferArguments,
     if device_map is None:
         device_map = args.merge_device_map
     logger.info(f'merge_device_map: {device_map}')
-    model, template = prepare_model_template(args, device_map=device_map, verbose=False)
+    model, template = prepare_model_template(args, device_map=device_map, task='export')
     logger.info('Merge LoRA...')
     Swift.merge_and_unload(model)
     model = model.model
@@ -133,7 +133,7 @@ def merge_lora(args: InferArguments,
 def prepare_model_template(args: InferArguments,
                            *,
                            device_map: Optional[str] = None,
-                           verbose: bool = True,
+                           task: Literal['infer', 'export'] = 'infer',
                            automodel_class=None) -> Tuple[PreTrainedModel, Template]:
     from .sft import get_default_device_map
     if is_torch_npu_available():
@@ -188,25 +188,7 @@ def prepare_model_template(args: InferArguments,
         revision=args.model_revision,
         quant_method=args.quant_method,
         **kwargs)
-    if verbose:
-        logger.info(f'model_config: {model.config}')
-
-    generation_config = GenerationConfig(
-        max_new_tokens=args.max_new_tokens,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        do_sample=args.do_sample,
-        repetition_penalty=args.repetition_penalty,
-        num_beams=args.num_beams,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id)
-    set_generation_config(model, generation_config)
-    logger.info(f'model.generation_config: {model.generation_config}')
 
-    if model.generation_config.num_beams != 1:
-        args.stream = False
-        logger.info('Setting args.stream: False')
     if model.max_model_len is None:
         model.max_model_len = args.max_model_len
     elif args.max_model_len is not None:
@@ -215,6 +197,26 @@ def prepare_model_template(args: InferArguments,
     else:
         raise ValueError('args.max_model_len exceeds the maximum max_model_len supported by the model.'
                          f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
+    if task == 'infer':
+        logger.info(f'model_config: {model.config}')
+        generation_config = GenerationConfig(
+            max_new_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            do_sample=args.do_sample,
+            repetition_penalty=args.repetition_penalty,
+            num_beams=args.num_beams,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id)
+        model._generation_config_origin = model.generation_config
+        set_generation_config(model, generation_config)
+        logger.info(f'model.generation_config: {model.generation_config}')
+
+        if model.generation_config.num_beams != 1:
+            args.stream = False
+            logger.info('Setting args.stream: False')
+
     # Preparing LoRA
     if is_adapter(args.sft_type) and args.ckpt_dir is not None:
         if isinstance(args, DeployArguments) and args.lora_request_list is not None:
@@ -227,7 +229,7 @@ def prepare_model_template(args: InferArguments,
     model = model.to(model.dtype)
     model.requires_grad_(False)
 
-    if verbose:
+    if task == 'infer':
         show_layers(model)
         logger.info(model)
         logger.info(get_model_info(model))
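
Usage note (a hedged sketch, not part of the diff): after this change, callers select the loading behavior through the new task parameter instead of verbose. The snippet below only restates what the hunks above show; args stands for an InferArguments instance and the 'auto' device_map value is a placeholder.

# Sketch assuming the post-change signature of prepare_model_template.

# task='infer' (the default): builds a GenerationConfig from args, stashes
# the original config on model._generation_config_origin before calling
# set_generation_config, disables args.stream when num_beams != 1, and logs
# the model config, layers, and model info.
model, template = prepare_model_template(args, device_map='auto', task='infer')

# task='export' (what merge_lora now passes): loads the model and template
# only, skipping the generation-config setup and verbose logging that the
# inference path needs, since the model is loaded just to be merged.
model, template = prepare_model_template(args, device_map=args.merge_device_map, task='export')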