Skip to content

Commit c8a53f8

Browse files
committed
Merge commit 'ab712f13dbf219374b96bca8effedfdca760b6a2' into release/1.6
* commit 'ab712f13dbf219374b96bca8effedfdca760b6a2':
  - fix system='' bug (#378)
  - fix system='' bug (#374)
  - update compute loss (#375)
  - fix loss (#372)
  - Fix length penalty (#371)
  - fix lazy_tokenize bug (#369)
2 parents: e43d416 + ab712f1 · commit c8a53f8

File tree

12 files changed

+75
-47
lines changed

12 files changed

+75
-47
lines changed

docs/source/LLM/命令行参数.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
- `--dataset_test_ratio`: 用于指定子数据集切分成训练集和验证集的比例, 默认为`0.01`. 如果子数据集已经进行了训练集和验证集的切分, 则此参数无效.
2828
- `--train_dataset_sample`: 对训练集进行采样, 默认是`20000`, 用于加快训练的速度. 该参数是为了避免数据集过大, 单个epoch训练时间过长的问题. 如果你指定为`-1`, 则使用完整的训练集进行训练.
2929
- `--val_dataset_sample`: 对验证集进行采样, 默认是`None`, 自动选取合适数量的数据集数量进行验证. 如果你指定为`-1`, 则使用完整的验证集进行验证.
30-
- `--system`: 对话模板中使用的system, 默认为`None`, 即使用模型默认的system.
30+
- `--system`: 对话模板中使用的system, 默认为`None`, 即使用模型默认的system. 如果指定为'', 则不使用system.
3131
- `--max_length`: token的最大长度, 默认为`2048`. 可以避免个别过长的数据样本造成OOM的问题. 当指定`--truncation_strategy delete`时, 如果某数据样本长度超过max_length, 我们会删除该数据样本. 如果指定`--truncation_strategy truncation_left`时, 我们会切除最前面的token: `input_ids[-max_length:]`. 如果设置为-1, 则无限制.
3232
- `--truncation_strategy`: 默认是`'delete'`表示把超过max_length的句子从数据集中删除. `'truncation_left'`表示会将超过文本的左边给切除掉, 这可能会切到special token, 会影响性能, 并不推荐.
3333
- `--check_dataset_strategy`: 默认值为`'none'`, 即不做检查. 如果你训练的模型是LLM, 则推荐使用`'warning'`作为数据检查的策略. 如果你的训练目标为句子分类等任务, 则建议设置为'`none`'.

docs/source/LLM/自定义与拓展.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ class CustomTemplateType:
317317
register_template(
318318
CustomTemplateType.tigerbot,
319319
Template(['{{SYSTEM}}'], ['\n\n### Instruction:\n{{QUERY}}\n\n### Response:\n'], [],
320-
[['eos_token_id']], ''))
320+
[['eos_token_id']]))
321321

322322
if __name__ == '__main__':
323323
# test template

examples/pytorch/llm/custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def get_tigerbot_model_tokenizer(model_dir: str,
7272
CustomTemplateType.tigerbot,
7373
Template(['{{SYSTEM}}'],
7474
['\n\n### Instruction:\n{{QUERY}}\n\n### Response:\n'], [],
75-
[['eos_token_id']], ''))
75+
[['eos_token_id']]))
7676

7777

7878
def _preprocess_stsb(dataset: HfDataset) -> HfDataset:

swift/llm/deploy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
104104
return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
105105
kwargs = {'max_new_tokens': request.max_tokens}
106106
for key in [
107-
'n', 'stop', 'best_of', 'frequency_penalty', 'presence_penalty',
108-
'num_beams'
107+
'n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty',
108+
'presence_penalty', 'num_beams'
109109
]:
110110
kwargs[key] = getattr(request, key)
111111
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:

swift/llm/sft.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
162162
td0, tkwargs0 = template.encode(train_dataset[0])
163163
print_example(td0, tokenizer, tkwargs0)
164164
train_dataset = LazyLLMDataset(train_dataset, template)
165-
val_dataset = LazyLLMDataset(val_dataset, template)
165+
if val_dataset is not None:
166+
val_dataset = LazyLLMDataset(val_dataset, template)
166167

167168
padding_to = args.max_length if args.sft_type == 'longlora' else None
168169
data_collator = partial(template.data_collator, padding_to=padding_to)

swift/llm/utils/argument.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,8 @@ def __post_init__(self) -> None:
522522
self.infer_backend = 'vllm'
523523
if self.infer_backend == 'vllm':
524524
assert self.quantization_bit == 0, 'VLLM does not support bnb.'
525-
assert support_vllm, f'vllm not support `{self.model_type}`'
525+
if not support_vllm:
526+
logger.warning(f'vllm not support `{self.model_type}`')
526527
if self.sft_type == 'lora':
527528
assert self.merge_lora_and_save is True, (
528529
'To use VLLM, you need to provide the complete weight parameters. '

swift/llm/utils/protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class XRequestConfig:
4545
best_of: Optional[int] = None
4646
presence_penalty: float = 0.
4747
frequency_penalty: float = 0.
48+
length_penalty: float = 1.
4849

4950
# additional
5051
num_beams: int = 1

swift/llm/utils/template.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def _has_system(prefix: Prompt) -> bool:
104104
return False
105105

106106

107+
def _replace_system(prefix: Prompt) -> Prompt:
108+
res = []
109+
for p in prefix:
110+
if '{{SYSTEM}}' in p:
111+
p = p.replace('{{SYSTEM}}', '')
112+
res.append(p)
113+
return res
114+
115+
107116
class Template:
108117

109118
def __init__(self,
@@ -113,11 +122,13 @@ def __init__(self,
113122
suffix: Prompt,
114123
default_system: Optional[str] = None,
115124
prefix_has_system: Optional[Prompt] = None) -> None:
116-
self.prefix = prefix
125+
if default_system == '':
126+
default_system = None
117127
if _has_system(prefix):
118128
assert prefix_has_system is None, 'The prefix already contains {{SYSTEM}}.'
119-
assert default_system is not None, 'You need to provide the `default_system`.'
120129
prefix_has_system = prefix
130+
prefix = _replace_system(prefix)
131+
self.prefix = prefix
121132
self.prefix_has_system = prefix_has_system
122133
if self.prefix_has_system is None:
123134
assert default_system is None, 'The template does not support `system`.'
@@ -157,7 +168,10 @@ def _init_template(self,
157168
assert self._is_init is False, 'The template has been initialized.'
158169
self._is_init = True
159170
self.tokenizer = tokenizer
160-
if default_system is not None:
171+
# if default_system is None. not change self.default_system
172+
if default_system == '':
173+
self.default_system = None
174+
elif default_system is not None:
161175
assert self.prefix_has_system is not None, 'The template does not support `system`.'
162176
self.default_system = default_system
163177
self.max_length = max_length
@@ -189,6 +203,8 @@ def encode(
189203
if system is None:
190204
if self.use_default_system:
191205
system = self.default_system
206+
elif system == '':
207+
system = None
192208
else:
193209
assert self.prefix_has_system is not None, 'The template does not support `system`.'
194210
inputs, tokenizer_kwargs = self._encode(query, response, history,
@@ -299,7 +315,6 @@ def _encode(
299315
res_context_list: List[Context] = []
300316
compute_loss_idx: List[float] = []
301317
if system is None:
302-
assert self.prefix != self.prefix_has_system, f'template.prefix: {self.prefix}'
303318
prefix = self.prefix
304319
else:
305320
prefix = self.prefix_has_system
@@ -586,22 +601,21 @@ def data_collator(self,
586601

587602
register_template(
588603
TemplateType.yi_vl,
589-
YiVLTemplate(['{{SYSTEM}}\n\n'],
590-
['### Human: ', [-200], '\n{{QUERY}}\n### Assistant:\n'],
591-
['\n'], ['\n###'], yi_vl_default_system),
604+
YiVLTemplate([], ['### Human: ', [-200], '\n{{QUERY}}\n### Assistant:\n'],
605+
['\n'], ['\n###'], yi_vl_default_system, ['{{SYSTEM}}\n\n']),
592606
use_model=True,
593607
infer_media_type='round',
594608
lazy_tokenize=True)
595609

596610
register_template(
597611
TemplateType.baichuan,
598612
Template(['{{SYSTEM}}'], [[195], '{{QUERY}}', [196]], [],
599-
[['eos_token_id']], ''))
613+
[['eos_token_id']]))
600614
register_template(
601615
TemplateType.chatglm2,
602616
Template([[64790, 64792], '{{SYSTEM}}'],
603617
['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], ['\n\n'],
604-
[['eos_token_id']], ''))
618+
[['eos_token_id']]))
605619

606620
register_template(
607621
TemplateType.chatglm_generation,
@@ -818,29 +832,29 @@ def get_generate_ids(generate_ids: Tensor,
818832
register_template(
819833
TemplateType.xverse,
820834
Template(['{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: '],
821-
[['eos_token_id']], [['eos_token_id']], ''))
835+
[['eos_token_id']], [['eos_token_id']]))
822836
register_template(TemplateType.yuan,
823837
Template([], ['{{QUERY}}<sep>'], None, [['eos_token_id']]))
824838
register_template(
825839
TemplateType.ziya,
826840
Template([['bos_token_id'], '{{SYSTEM}}'], ['<human>:{{QUERY}}\n<bot>:'],
827-
['\n'], [['eos_token_id']], ''))
841+
['\n'], [['eos_token_id']]))
828842

829843
register_template(
830844
TemplateType.skywork,
831845
Template(['<s>{{SYSTEM}}'], ['</s><s>[USER]{{QUERY}}[SEP][BOT]'], None,
832-
['[SEP]</s>'], ''))
846+
['[SEP]</s>']))
833847

834848
register_template(
835849
TemplateType.bluelm,
836850
Template([['bos_token_id'], '{{SYSTEM}}'], ['[|Human|]:{{QUERY}}[|AI|]:'],
837-
[], [['eos_token_id']], ''))
851+
[], [['eos_token_id']]))
838852

839853
register_template(
840854
TemplateType.codefuse_codellama,
841855
Template(['{{SYSTEM}}'], [
842856
'<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'
843-
], [], [['eos_token_id']], ''))
857+
], [], [['eos_token_id']]))
844858

845859
register_template(
846860
TemplateType.codefuse,
@@ -867,12 +881,12 @@ def get_generate_ids(generate_ids: Tensor,
867881
register_template(
868882
TemplateType.sus,
869883
Template(['{{SYSTEM}}'], ['### Human: {{QUERY}}\n\n### Assistant: '],
870-
['<|endoftext|>'], ['<|endoftext|>'], ''))
884+
['<|endoftext|>'], ['<|endoftext|>']))
871885

872886
register_template(
873887
TemplateType.orion,
874888
Template(['<s>{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: </s>'],
875-
['</s>'], ['</s>'], ''))
889+
['</s>'], ['</s>']))
876890

877891

878892
class CogAgentTemplate(Template):
@@ -939,7 +953,7 @@ def data_collator(self,
939953

940954
register_template(
941955
TemplateType.openbmb,
942-
Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>'], ''))
956+
Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']))
943957

944958

945959
def get_template(

swift/llm/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def _try_fetch(self, first_idx: int) -> Optional[Dict[str, Any]]:
205205
for i in [first_idx] + idx.tolist():
206206
data = self.dataset[i]
207207
res = self.template.encode(data)
208-
if res is not None:
208+
if len(res[0]) > 0:
209209
return res
210210

211211
def __len__(self) -> int:

swift/llm/utils/vllm_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def __init__(
139139
kwargs['top_p'] = top_p
140140
kwargs['repetition_penalty'] = repetition_penalty
141141
if num_beams > 1:
142-
assert 'use_beam_search' not in kwargs and 'best_of' not in kwargs
142+
best_of = kwargs.get('best_of')
143+
assert 'use_beam_search' not in kwargs and best_of is None
143144
kwargs['use_beam_search'] = True
144145
kwargs['best_of'] = num_beams
145146
kwargs['n'] = n

0 commit comments

Comments
 (0)