
Commit 9873fdc

fix fp16 & full bug (#203)

1 parent: a30d235

12 files changed: 80 additions, 21 deletions

README.md

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 - bluelm series: [bluelm-7b](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base/summary), [bluelm-7b-chat](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat/summary), [bluelm-7b-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base-32K/summary), [bluelm-7b-chat-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat-32K/summary)
 - mistral series: [mistral-7b](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1/summary), [mistral-7b-chat](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)
 - yi series: [yi-6b](https://modelscope.cn/models/01ai/Yi-6B/summary), [yi-34b](https://modelscope.cn/models/01ai/Yi-34B/summary), [yi-34b-chat](https://modelscope.cn/models/01ai/Yi-34B-Chat/summary)
-- zephyr series: zephyr-7b-beta-chat(https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
+- zephyr series: [zephyr-7b-beta-chat](https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
 - ziya series: [ziya2-13b](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Base/summary), [ziya2-13b-chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
 - skywork series: [skywork-13b](https://modelscope.cn/models/skywork/Skywork-13B-base/summary), [skywork-13b-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat/summary)
 - other: [polylm-13b](https://modelscope.cn/models/damo/nlp_polylm_13b_text_generation/summary), [seqgpt-560m](https://modelscope.cn/models/damo/nlp_seqgpt-560m/summary)

README_CN.md

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 - bluelm 系列: [bluelm-7b](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base/summary), [bluelm-7b-chat](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat/summary), [bluelm-7b-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base-32K/summary), [bluelm-7b-chat-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat-32K/summary)
 - mistral 系列: [mistral-7b](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1/summary), [mistral-7b-chat](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)
 - yi 系列: [yi-6b](https://modelscope.cn/models/01ai/Yi-6B/summary), [yi-34b](https://modelscope.cn/models/01ai/Yi-34B/summary), [yi-34b-chat](https://modelscope.cn/models/01ai/Yi-34B-Chat/summary)
-- zephyr 系列: zephyr-7b-beta-chat(https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
+- zephyr 系列: [zephyr-7b-beta-chat](https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
 - ziya 系列: [ziya2-13b](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Base/summary), [ziya2-13b-chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
 - skywork 系列: [skywork-13b](https://modelscope.cn/models/skywork/Skywork-13B-base/summary), [skywork-13b-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat/summary)
 - other: [polylm-13b](https://modelscope.cn/models/damo/nlp_polylm_13b_text_generation/summary), [seqgpt-560m](https://modelscope.cn/models/damo/nlp_seqgpt-560m/summary)

swift/llm/infer.py

Lines changed: 3 additions & 2 deletions

@@ -14,7 +14,8 @@
 from swift.utils import (append_to_jsonl, get_logger, print_model_info,
                          read_multi_line, seed_everything, show_layers)
 from .utils import (InferArguments, Template, get_dataset, get_model_tokenizer,
-                    get_template, inference, inference_stream)
+                    get_template, inference, inference_stream,
+                    set_generation_config)
 
 logger = get_logger()
 
@@ -141,7 +142,7 @@ def prepare_model_template(
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
     logger.info(f'generation_config: {generation_config}')
-    model.generation_config = generation_config
+    set_generation_config(model, generation_config)
     return model, template
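In `prepare_model_template`, the freshly built `GenerationConfig` is now merged onto the model through the helper instead of overwriting `model.generation_config` outright. A minimal sketch of the call-site pattern, assuming an already loaded `model` and `tokenizer`; the sampling values are illustrative, not SWIFT defaults:

from transformers import GenerationConfig
from swift.llm.utils import set_generation_config

def apply_generation_config(model, tokenizer) -> None:
    # In SWIFT these values come from InferArguments; illustrative here.
    generation_config = GenerationConfig(
        max_new_tokens=1024,
        temperature=0.9,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    # Merge rather than overwrite: fields the model repo ships but the
    # user did not set are preserved on the model.
    set_generation_config(model, generation_config)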

swift/llm/rome.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
                          show_layers)
 from ..tuners.rome import RomeConfig
 from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
-                    get_template, inference)
+                    get_template, inference, set_generation_config)
 
 logger = get_logger()
 
@@ -72,7 +72,7 @@ def rome_infer(args: RomeArguments) -> None:
     logger.info(f'generation_config: {generation_config}')
     if args.overwrite_generation_config:
         generation_config.save_pretrained(args.ckpt_dir)
-    model.generation_config = generation_config
+    set_generation_config(model, generation_config)
 
     # Inference
     if args.eval_human:
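The ROME path takes the same fix as `infer.py`; note the context line above it: with `overwrite_generation_config`, the config is also persisted next to the checkpoint. A short sketch of what `save_pretrained` does for a `GenerationConfig` (the directory path is illustrative):

from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=512, do_sample=False)
# Writes generation_config.json into the directory, so a later
# GenerationConfig.from_pretrained(ckpt_dir) restores the same defaults.
generation_config.save_pretrained('output/checkpoint-100')
reloaded = GenerationConfig.from_pretrained('output/checkpoint-100')
assert reloaded.max_new_tokens == 512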

swift/llm/sft.py

Lines changed: 9 additions & 3 deletions

@@ -18,8 +18,9 @@
                          seed_everything, show_layers)
 from .utils import (SftArguments, Template, add_self_cognition_dataset,
                     data_collate_fn, dataset_map, find_all_linear_for_lora,
-                    get_dataset, get_model_tokenizer, get_template,
-                    print_example, sort_by_max_length, stat_dataset)
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, print_example,
+                    set_generation_config, sort_by_max_length, stat_dataset)
 
 logger = get_logger()
 
@@ -182,11 +183,15 @@ def llm_sft(args: SftArguments) -> str:
         pad_token_id=tokenizer.pad_token_id,
         eos_token_id=tokenizer.eos_token_id)
     logger.info(f'generation_config: {generation_config}')
+    set_generation_config(model, generation_config)
     evaluation_strategy = IntervalStrategy.STEPS
     load_best_model_at_end = True
     if val_dataset is None:
         evaluation_strategy = IntervalStrategy.NO
         load_best_model_at_end = False
+    additional_saved_files = []
+    if args.sft_type == 'full':
+        additional_saved_files = get_additional_saved_files(args.model_type)
     training_args = Seq2SeqTrainingArguments(
         output_dir=args.output_dir,
         evaluation_strategy=evaluation_strategy,
@@ -230,7 +235,8 @@ def llm_sft(args: SftArguments) -> str:
         only_save_model=args.only_save_model,
         train_sampler_random=args.train_sampler_random,
         report_to=args.report_to,
-        deepspeed=args.deepspeed)
+        deepspeed=args.deepspeed,
+        additional_saved_files=additional_saved_files)
 
     if args.gradient_checkpointing:
         model.enable_input_require_grads()
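For full-parameter SFT, `additional_saved_files` is threaded into SWIFT's `Seq2SeqTrainingArguments` so that checkpoints can carry model-repo assets (a font for qwen-vl, mel filterbanks for qwen-audio). The trainer-side handling is not part of this diff; the helper below is only a hypothetical sketch of the intended copy step (`copy_additional_saved_files` is not a SWIFT API):

import os
import shutil
from typing import List

def copy_additional_saved_files(model_dir: str, output_dir: str,
                                files: List[str]) -> None:
    # Copy assets such as SimSun.ttf next to a full checkpoint so the
    # saved directory remains loadable on its own.
    for fname in files:
        src = os.path.join(model_dir, fname)
        if os.path.exists(src):
            shutil.copy2(src, os.path.join(output_dir, fname))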

swift/llm/utils/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -5,9 +5,9 @@
                       get_dataset_from_repo, load_dataset_from_local,
                       load_ms_dataset, register_dataset)
 from .model import (MODEL_MAPPING, GetModelTokenizerFunction, LoRATM,
-                    ModelType, get_default_lora_target_modules,
-                    get_default_template_type, get_model_tokenizer,
-                    get_model_tokenizer_from_repo,
+                    ModelType, get_additional_saved_files,
+                    get_default_lora_target_modules, get_default_template_type,
+                    get_model_tokenizer, get_model_tokenizer_from_repo,
                     get_model_tokenizer_from_sdk, register_model)
 from .preprocess import (AlpacaPreprocessor, ClsPreprocessor,
                          ComposePreprocessor, ConversationsPreprocessor,
@@ -19,5 +19,5 @@
 from .utils import (data_collate_fn, dataset_map, download_dataset,
                     find_all_linear_for_lora, history_to_messages, inference,
                     inference_stream, limit_history_length,
-                    messages_to_history, print_example, sort_by_max_length,
-                    stat_dataset)
+                    messages_to_history, print_example, set_generation_config,
+                    sort_by_max_length, stat_dataset)

swift/llm/utils/argument.py

Lines changed: 8 additions & 4 deletions

@@ -14,9 +14,8 @@
 from swift.hub import HubApi, ModelScopeConfig
 from swift.utils import (add_version_to_work_dir, broadcast_string,
                          get_dist_setting, is_dist, is_master)
-from .dataset import (DATASET_MAPPING, DatasetName, get_custom_dataset,
-                      register_dataset)
-from .model import (MODEL_MAPPING, ModelType, dtype_mapping,
+from .dataset import DATASET_MAPPING, get_custom_dataset, register_dataset
+from .model import (MODEL_MAPPING, dtype_mapping,
                     get_default_lora_target_modules, get_default_template_type)
 from .template import TEMPLATE_MAPPING, TemplateType
 
@@ -431,8 +430,13 @@ def select_dtype(
     assert torch_dtype in {torch.float16, torch.bfloat16, torch.float32}
     if torch_dtype == torch.float16:
         if isinstance(args, SftArguments) and args.sft_type == 'full':
+            args.dtype = 'fp32'
             torch_dtype = torch.float32
-            logger.warning('Setting torch_dtype: torch.float32')
+            logger.warning(
+                'Fine-tuning with full parameters does not support fp16, and is prone to NaN. '
+                'We will use the fp32 & AMP approach, which consumes approximately twice the memory of bf16.'
+            )
+            logger.info(f'Setting torch_dtype: {torch_dtype}')
         fp16, bf16 = True, False
     elif torch_dtype == torch.bfloat16:
         support_bf16 = torch.cuda.is_bf16_supported()
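The net effect of the `select_dtype` change: requesting fp16 for full-parameter SFT now yields fp32 weights with fp16 autocast (AMP), instead of fp16 weights that are prone to NaN. A standalone sketch of that decision table, ignoring the bf16 hardware-support check (`pick_torch_dtype` is illustrative, not the SWIFT function):

import torch

def pick_torch_dtype(dtype: str, sft_type: str):
    # Mirrors the fixed branch in select_dtype: full-parameter fine-tuning
    # must not hold fp16 weights, so weights stay fp32 and fp16 autocast
    # (AMP) provides the mixed-precision speedup instead.
    torch_dtype = {'fp16': torch.float16,
                   'bf16': torch.bfloat16,
                   'fp32': torch.float32}[dtype]
    fp16, bf16 = False, False
    if torch_dtype == torch.float16:
        if sft_type == 'full':
            torch_dtype = torch.float32  # weights in fp32
        fp16 = True                      # AMP still runs in fp16
    elif torch_dtype == torch.bfloat16:
        bf16 = True                      # (hardware check omitted)
    return torch_dtype, fp16, bf16

print(pick_torch_dtype('fp16', 'full'))  # (torch.float32, True, False)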

swift/llm/utils/model.py

Lines changed: 22 additions & 0 deletions

@@ -827,6 +827,13 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
     ]
     get_qwen_function = kwargs.pop('get_qwen_function',
                                    get_model_tokenizer_qwen_chat)
+    tokenizer_config = get_tokenizer_config(model_dir)
+    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
+    tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)
+    tokenizer_cls._auto_class = 'AutoTokenizer'
+    tokenizer_cls.IMAGE_ST = ()  # fix no attr `self.IMAGE_ST` bug
+    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
+        model_dir, trust_remote_code=True)
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
@@ -870,6 +877,13 @@ def get_model_tokenizer_qwen_audio(model_dir: str,
                                    load_model: bool = True,
                                    **kwargs):
     get_qwen_function = kwargs.pop('get_qwen_function')
+    tokenizer_config = get_tokenizer_config(model_dir)
+    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
+    tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)
+    tokenizer_cls._auto_class = 'AutoTokenizer'
+    tokenizer_cls.AUDIO_ST = ()  # fix no attr `self.AUDIO_ST` bug
+    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
+        model_dir, trust_remote_code=True)
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
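Both hunks apply the same fix: resolve the remote-code tokenizer class ahead of time, register it for `AutoTokenizer`, and define the class attribute its `__init__` reads but never sets (`IMAGE_ST` for qwen-vl, `AUDIO_ST` for qwen-audio). A generic sketch of that pattern, assuming a local `model_dir` whose tokenizer_config.json declares an `auto_map` entry and a transformers version with this `get_class_from_dynamic_module` signature:

from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto.tokenization_auto import get_tokenizer_config

def load_patched_tokenizer(model_dir: str, attr: str, value):
    # Resolve the tokenizer class that trust_remote_code would load.
    class_ref = get_tokenizer_config(model_dir)['auto_map']['AutoTokenizer'][0]
    tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)
    # Register for AutoTokenizer so save_pretrained round-trips, then
    # define the attribute __init__ expects but never sets.
    tokenizer_cls._auto_class = 'AutoTokenizer'
    setattr(tokenizer_cls, attr, value)  # e.g. 'IMAGE_ST', () / 'AUDIO_ST', ()
    return tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)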
@@ -1148,6 +1162,14 @@ def get_model_tokenizer(
     return model, tokenizer
 
 
+def get_additional_saved_files(model_type: str) -> List[str]:
+    if 'qwen-vl' in model_type:
+        return ['SimSun.ttf']
+    elif 'qwen-audio' in model_type:
+        return ['mel_filters.npz']
+    return []
+
+
 def get_default_template_type(model_type: str) -> Optional[str]:
     return MODEL_MAPPING[model_type].get('template')
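`get_additional_saved_files` is a plain substring lookup over the model type. A quick usage sketch (model-type strings follow SWIFT's ModelType naming):

from swift.llm.utils import get_additional_saved_files

# qwen-vl needs the SimSun.ttf font (used when drawing boxes on images);
# qwen-audio needs mel_filters.npz for its audio frontend.
print(get_additional_saved_files('qwen-vl-chat'))     # ['SimSun.ttf']
print(get_additional_saved_files('qwen-audio-chat'))  # ['mel_filters.npz']
print(get_additional_saved_files('qwen-7b-chat'))     # []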

swift/llm/utils/template.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 from torch import Tensor
 from transformers import PreTrainedTokenizerBase, StoppingCriteria

swift/llm/utils/utils.py

Lines changed: 13 additions & 2 deletions

@@ -28,8 +28,9 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from tqdm.auto import tqdm
-from transformers import (PreTrainedModel, PreTrainedTokenizerBase,
-                          StoppingCriteriaList, TextStreamer, trainer)
+from transformers import (GenerationConfig, PreTrainedModel,
+                          PreTrainedTokenizerBase, StoppingCriteriaList,
+                          TextStreamer, trainer)
 
 from swift.hub import ModelScopeConfig
 from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
@@ -540,6 +541,16 @@ def messages_to_history(messages: Messages) -> Dict[str, Any]:
     }
 
 
+def set_generation_config(model: Module,
+                          generation_config: GenerationConfig) -> None:
+    if hasattr(model, 'generation_config'):
+        old_generation_config = model.generation_config
+        for k, v in old_generation_config.__dict__.items():
+            if k not in generation_config.__dict__:
+                setattr(generation_config, k, v)
+    model.generation_config = generation_config
+
+
 # monkey patching
 MsDataset.load = _msdataset_ddp_load
 if is_ddp_plus_mp():
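The merge semantics are the heart of the fix: any field present on the model's shipped config but absent from the new one is carried over, while explicitly set fields win. A standalone check of that behavior, using a `SimpleNamespace` as a stand-in for the model (only the `generation_config` attribute matters at runtime, despite the `Module` type hint):

from types import SimpleNamespace
from transformers import GenerationConfig
from swift.llm.utils import set_generation_config

model = SimpleNamespace(generation_config=GenerationConfig(
    chat_format='chatml', max_window_size=6144))
new_config = GenerationConfig(max_new_tokens=512, do_sample=True)
set_generation_config(model, new_config)
assert model.generation_config.chat_format == 'chatml'  # merged in
assert model.generation_config.max_new_tokens == 512    # new value wins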
