Commit 2ea885e

[train] Support new special tokens (#4945)
1 parent 9d89f70 commit 2ea885e

9 files changed: +181 -3 lines changed


docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@
 - 🔥torch_dtype: Data type of the model weights; supports `float16`, `bfloat16`, `float32`. Defaults to None, in which case it is read from the config.json file.
 - attn_impl: Attention type; the options are `flash_attn`, `sdpa`, `eager`. Defaults to None, in which case it is read from config.json.
   - Note: Not all three implementations are necessarily supported; this depends on the corresponding model.
+- new_special_tokens: Special tokens to be added. Defaults to `[]`. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/new_special_tokens)
 - num_labels: Required for classification models (i.e., `--task_type seq_cls`). The number of labels; defaults to None.
 - problem_type: Required for classification models (i.e., `--task_type seq_cls`). The options are 'regression', 'single_label_classification', 'multi_label_classification'. Defaults to None, in which case it is set automatically based on num_labels and the dataset type.
 - rope_scaling: Type of rope; supports `linear`, `dynamic`, `yarn`. Use it together with `max_length`. Defaults to None.
@@ -639,7 +640,7 @@ App arguments inherit from the [deployment arguments](#部署参数) and [Web-UI arguments](#Web-UI参数)
 ## Specific Model Arguments
 Specific model arguments can be set via `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
 
-### qwen2_vl, qvq, qwen2_5_vl
+### qwen2_vl, qvq, qwen2_5_vl, mimo_vl
 The parameter meanings are the same as in the `qwen_vl_utils` or `qwen_omni_utils` library; see [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24)
 
 - IMAGE_FACTOR: Default is 28.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ Hints:
 - 🔥torch_dtype: Data type of model weights, supports `float16`, `bfloat16`, `float32`. The default is None, and it is read from the 'config.json' file.
 - attn_impl: The type of attention, with options including `flash_attn`, `sdpa`, and `eager`. The default is None, which reads from `config.json`.
   - Note: These three implementations may not all be supported, depending on the support of the corresponding model.
+- new_special_tokens: The special tokens to be added. Default is `[]`. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/new_special_tokens).
 - num_labels: This parameter is required for classification models (i.e., `--task_type seq_cls`). It represents the number of labels, with a default value of None.
 - problem_type: This parameter is required for classification models (i.e., `--task_type seq_cls`). The options are 'regression', 'single_label_classification', and 'multi_label_classification'. The default value is None, and it will be automatically set based on the number of labels and the dataset type.
 - rope_scaling: Type of rope, supports `linear` and `dynamic` and `yarn`, should be used in conjunction with `max_length`. Default is None.
@@ -658,7 +659,7 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
 
 Specific model arguments can be set using `--model_kwargs` or environment variables, for example: `--model_kwargs '{"fps_max_frames": 12}'` or `FPS_MAX_FRAMES=12`.
 
-### qwen2_vl, qvq, qwen2_5_vl
+### qwen2_vl, qvq, qwen2_5_vl, mimo_vl
 The parameter meanings are the same as in the `qwen_vl_utils` or `qwen_omni_utils` library. You can refer to [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24)
 
 - IMAGE_FACTOR: Default is 28
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+CUDA_VISIBLE_DEVICES=0 \
+swift infer \
+    --adapters output/vx-xxx/checkpoint-xxx \
+    --max_batch_size 16 \
+    --load_data_args true \
+    --temperature 0
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+CUDA_VISIBLE_DEVICES=0 \
+swift export \
+    --adapters output/vx-xxx/checkpoint-xxx \
+    --merge_lora true
+
+
+# infer
+CUDA_VISIBLE_DEVICES=0 \
+swift infer \
+    --adapters output/vx-xxx/checkpoint-xxx-merged \
+    --max_batch_size 16 \
+    --load_data_args true \
+    --temperature 0
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+<|0|>
+<|1|>
+<|2|>
+<|3|>
+<|4|>
+<|5|>
+<|6|>
+<|7|>
+<|8|>
+<|9|>
+<|10|>
+<|11|>
+<|12|>
+<|13|>
+<|14|>
+<|15|>
+<|16|>
+<|17|>
+<|18|>
+<|19|>
+<|20|>
+<|21|>
+<|22|>
+<|23|>
+<|24|>
+<|25|>
+<|26|>
+<|27|>
+<|28|>
+<|29|>
+<|30|>
+<|31|>
+<|32|>
+<|33|>
+<|34|>
+<|35|>
+<|36|>
+<|37|>
+<|38|>
+<|39|>
+<|40|>
+<|41|>
+<|42|>
+<|43|>
+<|44|>
+<|45|>
+<|46|>
+<|47|>
+<|48|>
+<|49|>
+<|50|>
+<|51|>
+<|52|>
+<|53|>
+<|54|>
+<|55|>
+<|56|>
+<|57|>
+<|58|>
+<|59|>
+<|60|>
+<|61|>
+<|62|>
+<|63|>
+<|64|>
+<|65|>
+<|66|>
+<|67|>
+<|68|>
+<|69|>
+<|70|>
+<|71|>
+<|72|>
+<|73|>
+<|74|>
+<|75|>
+<|76|>
+<|77|>
+<|78|>
+<|79|>
+<|80|>
+<|81|>
+<|82|>
+<|83|>
+<|84|>
+<|85|>
+<|86|>
+<|87|>
+<|88|>
+<|89|>
+<|90|>
+<|91|>
+<|92|>
+<|93|>
+<|94|>
+<|95|>
+<|96|>
+<|97|>
+<|98|>
+<|99|>
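A token file with this shape (one `<|i|>` token per line) does not need to be written by hand. The following is a minimal sketch of one way to produce it, not the method used in the commit; the helper name `make_index_tokens` is hypothetical. It also checks that whitespace splitting, which is how the argument parser in this commit reads `.txt` files, recovers the same list.

```python
from typing import List


def make_index_tokens(count: int) -> List[str]:
    """Return tokens of the form <|i|> for i in range(count). Hypothetical helper."""
    return [f"<|{i}|>" for i in range(count)]


tokens = make_index_tokens(100)

# One token per line, matching the layout of the file in the hunk above.
tokens_txt = "\n".join(tokens) + "\n"

# The commit reads token files with str.split(), so any whitespace
# separator (newline, space, tab) round-trips to the same list.
parsed = tokens_txt.split()
```

Because the reader splits on whitespace, a token that itself contains a space could not be expressed in such a file.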
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# 4 * 26GB
+# This example is just a demo showing how to add new_special_tokens.
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+swift sft \
+    --model Qwen/Qwen2.5-7B-Instruct \
+    --train_type lora \
+    --dataset 'swift/new_special_tokens' \
+    --split_dataset_ratio 0.01 \
+    --new_special_tokens examples/train/new_special_tokens/tokens.txt \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 5 \
+    --per_device_train_batch_size 16 \
+    --per_device_eval_batch_size 16 \
+    --padding_free true \
+    --attn_impl flash_attn \
+    --learning_rate 1e-4 \
+    --lora_rank 16 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --modules_to_save embed_tokens lm_head \
+    --gradient_accumulation_steps 1 \
+    --eval_steps 500 \
+    --save_steps 500 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2

swift/llm/argument/base_args/base_args.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ def load_args_from_ckpt(self) -> None:
             'model_revision',
             'torch_dtype',
             'attn_impl',
+            'new_special_tokens',
             'num_labels',
             'problem_type',
             # quant_args

swift/llm/argument/base_args/model_args.py

Lines changed: 18 additions & 1 deletion
@@ -3,7 +3,7 @@
 import math
 import os
 from dataclasses import dataclass, field
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 import json
 import torch
@@ -42,6 +42,7 @@ class ModelArguments:
     # flash_attn: It will automatically convert names based on the model.
     # None: It will be automatically selected between sdpa and eager.
     attn_impl: Literal['flash_attn', 'sdpa', 'eager', 'flex_attention', None] = None
+    new_special_tokens: List[str] = field(default_factory=list)
 
     num_labels: Optional[int] = None
     problem_type: Literal['regression', 'single_label_classification', 'multi_label_classification'] = None
@@ -149,9 +150,24 @@ def _init_model_info(self) -> torch.dtype:
         self._init_rope_scaling()
         return self.model_info.torch_dtype
 
+    def _init_new_special_tokens(self):
+        if isinstance(self.new_special_tokens, str):
+            self.new_special_tokens = [self.new_special_tokens]
+        new_special_tokens = []
+        for token in self.new_special_tokens:
+            if token.endswith('.txt'):
+                assert os.path.isfile(token), f'special_tokens_path: {token}'
+                with open(token, 'r') as f:
+                    text = f.read()
+                new_special_tokens += text.split()
+            else:
+                new_special_tokens.append(token)
+        self.new_special_tokens = new_special_tokens
+
     def __post_init__(self):
         if self.model is None:
             raise ValueError(f'Please set --model <model_id_or_path>`, model: {self.model}')
+        self._init_new_special_tokens()
         self.model_suffix = get_model_name(self.model)
         self._init_device_map()
         self._init_max_memory()
@@ -170,6 +186,7 @@ def get_model_kwargs(self):
             'max_memory': self.max_memory,
             'quantization_config': self.get_quantization_config(),
             'attn_impl': self.attn_impl,
+            'new_special_tokens': self.new_special_tokens,
             'rope_scaling': self.rope_scaling,
             'task_type': self.task_type,
             'num_labels': self.num_labels,
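The normalization done by `_init_new_special_tokens` can be tried outside the dataclass. The sketch below restates the same logic as a standalone function so that it can be run in isolation. The function name `expand_new_special_tokens` is hypothetical, and two details differ from the commit by assumption: it raises `FileNotFoundError` instead of using `assert`, and it opens files with an explicit UTF-8 encoding.

```python
import os
import tempfile
from typing import List, Union


def expand_new_special_tokens(value: Union[str, List[str]]) -> List[str]:
    """Expand a mix of literal tokens and '.txt' token files into a flat list.

    Hypothetical standalone restatement of the logic in the hunk above.
    """
    # A single string is treated as a one-element list.
    items = [value] if isinstance(value, str) else list(value)
    expanded: List[str] = []
    for item in items:
        if item.endswith(".txt"):
            # Entries ending in '.txt' are paths to whitespace-separated token files.
            if not os.path.isfile(item):
                raise FileNotFoundError(f"special_tokens_path: {item}")
            with open(item, "r", encoding="utf-8") as f:
                expanded += f.read().split()
        else:
            # Anything else is taken as a literal token.
            expanded.append(item)
    return expanded


# Usage: one token file plus one literal token, in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tokens.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("<|0|>\n<|1|>\n")
    result = expand_new_special_tokens([path, "<|extra|>"])
```

Note that the `.txt` suffix check means a literal token ending in `.txt` would be interpreted as a file path.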

swift/llm/model/register.py

Lines changed: 8 additions & 0 deletions
@@ -559,6 +559,7 @@ def get_model_tokenizer(
         quantization_config=None,
         max_memory: Union[str, Dict[str, Any]] = None,
         attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None,
+        new_special_tokens: Optional[List[str]] = None,
         rope_scaling: Optional[Dict[str, Any]] = None,
         automodel_class=None,
         task_type: Literal['causal_lm', 'seq_cls', 'reranker', 'generative_reranker'] = None,
@@ -617,6 +618,13 @@ def get_model_tokenizer(
         patch_getattr(processor.__class__, 'tokenizer')
     else:
         tokenizer = processor
+    if new_special_tokens:
+        num_new_tokens = tokenizer.add_special_tokens({'additional_special_tokens': new_special_tokens})
+        if num_new_tokens > 0:
+            logger.info(f'Added {num_new_tokens} new special tokens.')
+            if model.config.vocab_size < len(tokenizer):
+                model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+
     problem_type = kwargs.get('problem_type')
     if problem_type is None and model_info.num_labels == 1:
         problem_type = 'regression'
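The resize condition in the hunk above means the embedding matrix only grows when the tokenizer has outgrown the model's configured vocabulary. The sketch below models just that arithmetic in plain Python. It assumes that `pad_to_multiple_of=64` rounds the requested size up to the next multiple of 64, as the `transformers` documentation for `resize_token_embeddings` describes. The function name `resized_vocab_size` and the example sizes are hypothetical, and no model is loaded.

```python
def resized_vocab_size(config_vocab_size: int, tokenizer_len: int, multiple: int = 64) -> int:
    """Return the embedding row count expected after the check in the hunk above.

    Hypothetical helper: models the size arithmetic only, not the actual resize.
    """
    # Many checkpoints ship embeddings larger than the tokenizer; in that case
    # the new tokens fit into the existing rows and nothing is resized.
    if config_vocab_size >= tokenizer_len:
        return config_vocab_size
    # Otherwise round the tokenizer length up to the next multiple (assumed behavior).
    return ((tokenizer_len + multiple - 1) // multiple) * multiple


# Example sizes are made up for illustration.
grown = resized_vocab_size(32000, 32100)      # tokenizer outgrew the embeddings
unchanged = resized_vocab_size(32256, 32100)  # spare rows already available
```

With these made-up sizes, 32100 tokens round up to 32128 rows, while a model that already has 32256 rows is left alone.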
