
Commit 12e731a

Removing eos_token when doing inference. (#351)
1 parent: ace8b89

9 files changed: +65 additions, -30 deletions

README.md

Lines changed: 6 additions & 3 deletions

```diff
@@ -62,7 +62,8 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 
 
 ## 🎉 News
-- 2024.1.26: Support [yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
+- 2024.1.29: Support internlm2-math series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
+- 🔥2024.1.26: Support [yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
 - 2024.1.24: Support codefuse-codegeex2-6b-chat, codefuse-qwen-14b-chat.
 - 2024.1.23: Support orion series: orion-14b, [orion-14b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/orion_14b_chat).
 - 2024.1.20: Support [xverse-13b-256k](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/xverse_13b_256k), xverse-65b-v2, xverse-65b-chat.
@@ -164,7 +165,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )
 
-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +179,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
@@ -222,6 +223,8 @@ app_ui_main(infer_args)
 - [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder) series: deepseek-coder-1_3b, deepseek-coder-1_3b-instruct, deepseek-coder-6_7b, deepseek-coder-6_7b-instruct, deepseek-coder-33b, deepseek-coder-33b-instruct.
 - [codegeex2](https://github.com/THUDM/CodeGeeX2) series: codegeex2-6b.
 - [phi](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) series: phi2-3b.
+- Math:
+  - [internlm2-math](https://github.com/InternLM/InternLM-Math) series: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - Supported Datasets: [[Detailed Info]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
 - NLP:
   - General: 🔥alpaca-en(gpt4), 🔥alpaca-zh(gpt4), multi-alpaca-all, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, instruct-en, gpt4all-en, sharegpt-en, sharegpt-zh, tutu-v2-sft-mixture, wikipedia-zh, open-orca, open-orca-gpt4, sharegpt-gpt4.
```
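Taken together, the two README code changes leave the quickstart's inference step as below. A minimal sketch assuming the surrounding quickstart context: the `from swift.llm import ...` statement shown in the hunk above, and `best_model_checkpoint` returned by the earlier `sft_main` call (the path here is only a placeholder).

```python
from swift.llm import InferArguments, infer_main

# `best_model_checkpoint` is produced by the sft_main(...) call earlier in
# the quickstart; a placeholder path stands in for it here.
best_model_checkpoint = 'output/qwen-1_8b/vx-xxx/checkpoint-xxx'
infer_args = InferArguments(
    ckpt_dir=best_model_checkpoint,
    load_dataset_config=True,  # reuse the dataset config saved with the checkpoint
    val_dataset_sample=10)     # renamed from show_dataset_sample in this commit
result = infer_main(infer_args)
```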

README_CN.md

Lines changed: 5 additions & 2 deletions

```diff
@@ -60,6 +60,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 用户可以查看 [SWIFT官方文档](docs/source/GetStarted/快速使用.md) 来了解详细信息。
 
 ## 🎉 新闻
+- 2024.1.29: 支持internlm2-math系列: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - 2024.1.26: 支持[yi-vl-6b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_vl_6b_chat), yi-vl-34b-chat.
 - 2024.1.24: 支持codefuse-codegeex2-6b-chat, codefuse-qwen-14b-chat.
 - 2024.1.23: 支持orion系列: orion-14b, [orion-14b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/orion_14b_chat).
@@ -164,7 +165,7 @@ from swift.llm import (
     infer_main, sft_main, app_ui_main, merge_lora_main
 )
 
-model_type = ModelType.qwen_1_8b_chat
+model_type = ModelType.qwen_1_8b
 sft_args = SftArguments(
     model_type=model_type,
     train_dataset_sample=2000,
@@ -178,7 +179,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
@@ -222,6 +223,8 @@ app_ui_main(infer_args)
 - [deepseek-coder](https://github.com/deepseek-ai/DeepSeek-Coder) 系列: deepseek-coder-1_3b, deepseek-coder-1_3b-instruct, deepseek-coder-6_7b, deepseek-coder-6_7b-instruct, deepseek-coder-33b, deepseek-coder-33b-instruct.
 - [codegeex2](https://github.com/THUDM/CodeGeeX2) 系列: codegeex2-6b.
 - [phi](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) 系列: phi2-3b.
+- 数学:
+  - [internlm2-math](https://github.com/InternLM/InternLM-Math) 系列: internlm2-math-7b, internlm2-math-7b-chat, internlm2-math-20b, internlm2-math-20b-chat.
 - 支持的数据集: [[详细信息]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
 - NLP:
   - 通用: 🔥alpaca-en(gpt4), 🔥alpaca-zh(gpt4), multi-alpaca-all, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, instruct-en, gpt4all-en, sharegpt-en, sharegpt-zh, tutu-v2-sft-mixture, wikipedia-zh, open-orca, open-orca-gpt4, sharegpt-gpt4.
```

docs/source/LLM/LLM微调文档.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -64,7 +64,7 @@ torch.cuda.empty_cache()
 infer_args = InferArguments(
     ckpt_dir=best_model_checkpoint,
     load_dataset_config=True,
-    show_dataset_sample=10)
+    val_dataset_sample=10)
 # merge_lora_main(infer_args)
 result = infer_main(infer_args)
 torch.cuda.empty_cache()
```

docs/source/LLM/支持的模型和数据集.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -68,6 +68,10 @@
 |internlm2-20b|[Shanghai_AI_Laboratory/internlm2-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-20b/summary)|wqkv|default-generation-bos|✔|✘||
 |internlm2-20b-sft-chat|[Shanghai_AI_Laboratory/internlm2-chat-20b-sft](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-20b-sft/summary)|wqkv|internlm2|✔|✘||
 |internlm2-20b-chat|[Shanghai_AI_Laboratory/internlm2-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-20b/summary)|wqkv|internlm2|✔|✘||
+|internlm2-math-7b|[Shanghai_AI_Laboratory/internlm2-math-base-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-base-7b/summary)|wqkv|default-generation-bos|✔|✘||
+|internlm2-math-7b-chat|[Shanghai_AI_Laboratory/internlm2-math-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-7b/summary)|wqkv|internlm2|✔|✘||
+|internlm2-math-20b|[Shanghai_AI_Laboratory/internlm2-math-base-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-base-20b/summary)|wqkv|default-generation-bos|✔|✘||
+|internlm2-math-20b-chat|[Shanghai_AI_Laboratory/internlm2-math-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-math-20b/summary)|wqkv|internlm2|✔|✘||
 |deepseek-7b|[deepseek-ai/deepseek-llm-7b-base](https://modelscope.cn/models/deepseek-ai/deepseek-llm-7b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔||
 |deepseek-7b-chat|[deepseek-ai/deepseek-llm-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-llm-7b-chat/summary)|q_proj, k_proj, v_proj|deepseek|✔|✔||
 |deepseek-moe-16b|[deepseek-ai/deepseek-moe-16b-base](https://modelscope.cn/models/deepseek-ai/deepseek-moe-16b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✘||
```

scripts/utils/run_model_info.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -56,7 +56,14 @@ def get_model_info_readme_zh(data: List[str]) -> None:
     model_list = []
     for match in match_list:
         model_list += match[2].strip('.').split(',')
-    model_list = [model.strip() for model in model_list]
+    model_list_2 = []
+    for model in model_list:
+        model = model.strip()
+        model_match = re.search(r'\[(.+)\]\(.+\)', model)
+        if model_match is not None:
+            model = model_match.group(1)
+        model_list_2.append(model)
+    model_list = model_list_2
     model_type_list = [d[0] for d in data]
     print(set(model_type_list) - set(model_list))
     print(set(model_list) - set(model_type_list))
```
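The replacement loop additionally strips markdown link syntax, so README entries written as links compare equal to their bare `ModelType` strings. A small runnable illustration (the input string is one example entry from the README):

```python
import re

model = '[orion-14b-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/orion_14b_chat)'
model_match = re.search(r'\[(.+)\]\(.+\)', model)
if model_match is not None:
    model = model_match.group(1)  # keep only the link text
print(model)  # orion-14b-chat
```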

swift/llm/utils/model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -98,10 +98,10 @@ class ModelType:
     internlm2_20b_sft_chat = 'internlm2-20b-sft-chat'
     internlm2_20b_chat = 'internlm2-20b-chat'
     # internlm2-math
-    internlm2_math_7b_chat = 'internlm2-math-7b-chat'
     internlm2_math_7b = 'internlm2-math-7b'
-    internlm2_math_20b_chat = 'internlm2-math-20b-chat'
+    internlm2_math_7b_chat = 'internlm2-math-7b-chat'
     internlm2_math_20b = 'internlm2-math-20b'
+    internlm2_math_20b_chat = 'internlm2-math-20b-chat'
     # deepseek
     deepseek_7b = 'deepseek-7b'
     deepseek_7b_chat = 'deepseek-7b-chat'
```

swift/llm/utils/template.py

Lines changed: 26 additions & 18 deletions

```diff
@@ -86,16 +86,9 @@ def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
             if isinstance(stop_word, str):
                 if stop_word in text:
                     return True
-            elif isinstance(stop_word, list) and len(stop_word) > 0:
-                res = []
-                for sw in stop_word:
-                    if isinstance(sw, str):
-                        token = getattr(tokenizer, sw)
-                        assert token is not None
-                    else:
-                        token = sw
-                    res.append(token)
-                if input_ids[0].tolist()[-len(res):] == res:
+            else:  # list
+                if len(stop_word) > 0 and input_ids[0].tolist(
+                )[-len(stop_word):] == stop_word:
                     return True
         return False
 
```
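After this change, list-type stop words reaching `StopWordsCriteria` are assumed to be plain token-id lists; resolving tokenizer attribute names now happens once up front in `_preprocess_prompt`, added in the next hunk. A minimal sketch of the simplified suffix check, with made-up token ids:

```python
# Illustrative only: made-up token ids standing in for real generation output.
input_ids = [[100, 200, 2]]  # batch of generated token-id sequences
stop_word = [2]              # pre-resolved, e.g. ['eos_token_id'] -> [2]

if len(stop_word) > 0 and input_ids[0][-len(stop_word):] == stop_word:
    print('stop word matched; generation halts')
```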

```diff
@@ -132,6 +125,24 @@ def __init__(self,
         self.use_default_system = True
         self._is_init = False
 
+    @staticmethod
+    def _preprocess_prompt(tokenizer: PreTrainedTokenizerBase,
+                           value: Optional[Prompt]) -> Optional[Prompt]:
+        # e.g. [['eos_token_id']] -> [[2]]
+        if value is None:
+            return None
+        res_value = []
+        for v in value:
+            if isinstance(v, list):
+                res_v = []
+                for sub_v in v:
+                    if isinstance(sub_v, str):
+                        sub_v = getattr(tokenizer, sub_v)
+                    res_v.append(sub_v)
+                v = res_v
+            res_value.append(v)
+        return res_value
+
     def _init_template(self,
                        tokenizer: PreTrainedTokenizerBase,
                        default_system: Optional[str] = None,
```
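A self-contained illustration of what this preprocessing does. `FakeTokenizer` is a hypothetical stand-in (SWIFT passes the real tokenizer); only strings inside nested lists, which name tokenizer attributes such as `eos_token_id`, are rewritten, while top-level strings stay literal prompt text:

```python
class FakeTokenizer:
    # hypothetical stand-in; real tokenizers expose eos_token_id the same way
    eos_token_id = 2

value = [['eos_token_id'], 'some literal prompt text']
res_value = []
for v in value:
    if isinstance(v, list):
        # strings inside a list are attribute names on the tokenizer
        v = [getattr(FakeTokenizer, sub_v) if isinstance(sub_v, str) else sub_v
             for sub_v in v]
    res_value.append(v)
print(res_value)  # [[2], 'some literal prompt text']
```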
```diff
@@ -148,6 +159,10 @@ def _init_template(self,
         self.max_length = max_length
         self.truncation_strategy = truncation_strategy
         self.model = kwargs.get('model', None)
+        for key in ['prefix', 'prompt', 'chat_sep', 'suffix']:
+            value = getattr(self, key)
+            value = self._preprocess_prompt(tokenizer, value)
+            setattr(self, key, value)
 
     def encode(
             self, example: Dict[str,
@@ -254,14 +269,7 @@ def _encode_context_list(
                 add_special_tokens=False,
                 **curr_tokenizer_kwargs)['input_ids']
             else:
-                token_list = []
-                for c in context:
-                    if isinstance(c, str):
-                        token = getattr(tokenizer, c)
-                        assert token is not None
-                    else:
-                        token = c
-                    token_list.append(token)
+                token_list = context
             input_ids += token_list
             if i in compute_loss_idx:
                 labels += token_list
```
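Because `_init_template` now resolves `prefix`, `prompt`, `chat_sep`, and `suffix` once at init time, any list context reaching `_encode_context_list` already holds token ids, which is why the per-call resolution loop collapses to `token_list = context`.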

swift/llm/utils/utils.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -441,6 +441,8 @@ def inference_stream(model: PreTrainedModel,
         stream_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         stream_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        stream_config.bos_token_id = tokenizer.bos_token_id
     if stream_config.max_new_tokens is not None:
         stream_config.max_length = 20  # fix max_length, max_new_tokens warning
         stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
@@ -568,6 +570,8 @@ def inference(model: PreTrainedModel,
         generation_config.eos_token_id = tokenizer.eos_token_id
     if tokenizer.pad_token_id is not None:
         generation_config.pad_token_id = tokenizer.pad_token_id
+    if tokenizer.bos_token_id is not None:
+        generation_config.bos_token_id = tokenizer.bos_token_id
     if generation_config.max_new_tokens is not None:
         generation_config.max_length = 20  # fix max_length, max_new_tokens warning
     if template.suffix[-1] not in stop_words:
```
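Both `inference_stream` and `inference` now copy `bos_token_id` from the tokenizer onto the generation config, alongside the existing `eos_token_id` and `pad_token_id` handling. A hedged sketch of the same pattern in isolation; the gpt2 tokenizer is used purely for illustration, and any tokenizer with these attributes works:

```python
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # illustration only
generation_config = GenerationConfig(max_new_tokens=64)

# Copy the tokenizer's special-token ids onto the config when defined,
# as the commit does in inference() and inference_stream().
for name in ('eos_token_id', 'pad_token_id', 'bos_token_id'):
    token_id = getattr(tokenizer, name)
    if token_id is not None:
        setattr(generation_config, name, token_id)

print(generation_config.eos_token_id, generation_config.bos_token_id)
```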

tests/llm/test_run.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -251,7 +251,9 @@ def test_cogagent_instruct(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))
 
     def test_yi_vl_6b_chat(self):
         if not __name__ == '__main__':
@@ -272,7 +274,9 @@ def test_yi_vl_6b_chat(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))
 
     def test_dpo(self):
         if not __name__ == '__main__':
@@ -288,7 +292,9 @@ def test_dpo(self):
         torch.cuda.empty_cache()
         infer_main(
             InferArguments(
-                ckpt_dir=best_model_checkpoint, load_dataset_config=True))
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=2))
 
 
 def data_collate_fn(batch: List[Dict[str, Any]],
```
