@@ -255,6 +255,21 @@ def get_model_tokenizer_from_local(model_dir: str,
255255 InitModelStrategy .init_parameters (model , init_strategy )
256256
257257 model_info .config = model_config if model is None else model .config
258+
259+ pad_token = tokenizer .pad_token_id
260+ if pad_token is None :
261+ pad_token = tokenizer .eos_token_id
262+ if tokenizer .eos_token_id is None :
263+ tokenizer .eos_token_id = pad_token
264+ if tokenizer .pad_token_id is None :
265+ tokenizer .pad_token_id = pad_token
266+ assert tokenizer .eos_token_id is not None
267+ assert tokenizer .pad_token_id is not None
268+
269+ if model is not None :
270+ # fix seq classification task
271+ HfConfigFactory .set_model_config_attr (model , 'pad_token_id' , pad_token )
272+
258273 return model , tokenizer
259274
260275
@@ -583,20 +598,7 @@ def get_model_tokenizer(
583598 tokenizer .model_info = model_info
584599 tokenizer .model_meta = model_meta
585600
586- pad_token = tokenizer .pad_token_id
587- if pad_token is None :
588- pad_token = tokenizer .eos_token_id
589- if tokenizer .eos_token_id is None :
590- tokenizer .eos_token_id = pad_token
591- if tokenizer .pad_token_id is None :
592- tokenizer .pad_token_id = pad_token
593- assert tokenizer .eos_token_id is not None
594- assert tokenizer .pad_token_id is not None
595-
596601 if model is not None :
597- # fix seq classification task
598- HfConfigFactory .set_model_config_attr (model , 'pad_token_id' , pad_token )
599-
600602 model .model_info = model_info
601603 model .model_meta = model_meta
602604 model .model_dir = model_dir
0 commit comments