Commit 83cd3ea

fix some bugs in dpo (#1565)
* update
* update
* update
* update doc
* update
1 parent 65f9ef1 commit 83cd3ea

6 files changed: +67 additions, -22 deletions


docs/source/LLM/自定义与拓展.md

Lines changed: 8 additions & 5 deletions
@@ -123,27 +123,30 @@ system,instruction,input,output
 {"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
 
-`system` and `history` are optional.
+- `system` and `history` are optional.
 
 Language model (KTO)
 ```jsonl
 {"query": "11111", "response": "22222", "label": true}
 {"query": "aaaaa", "response": "bbbbb", "label": false}
 {"system": "123", "query": "AAAAA", "response": "BBBBB", "label": true, "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
-Note that `label` must be of type bool, not a string.
+- Note that `label` must be of type bool, not a string.
 
-`system` and `history` are optional.
+- `system` and `history` are optional.
 
 
-Vision multimodal LLM; different models support different numbers of images, refer to the best-practice document of the corresponding model (DPO/ORPO/SimPO/CPO)
+Vision multimodal LLM (DPO/ORPO/SimPO/CPO)
+
 ```jsonl
 {"system": "123", "query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 {"system": "123", "query": "aaaaa", "response": "bbbbb", "rejected_response": "ccccc", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 {"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
 
-`system` and `history` are optional.
+- Different models support different numbers of images; refer to the best-practice document of the corresponding model.
+
+- `system` and `history` are optional.
 
 **Tool-Calling Agent**

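The format rules documented above are easy to violate in hand-written data: DPO/ORPO/SimPO/CPO samples need a `rejected_response` alongside `response`, KTO samples need a JSON boolean `label` (not the string "true"), and `system`/`history` are optional. Below is a minimal validation sketch for such jsonl files; it is not part of this commit, and the file name `dpo_data.jsonl` is only a placeholder.

```python
import json


def check_dpo_line(line: str) -> None:
    """Validate one preference-pair sample (DPO/ORPO/SimPO/CPO)."""
    sample = json.loads(line)
    for key in ('query', 'response', 'rejected_response'):
        assert key in sample, f'missing key: {key}'
    # `system` and `history` are optional; history entries are [query, response] pairs.
    if 'history' in sample:
        assert all(len(turn) == 2 for turn in sample['history'])


def check_kto_line(line: str) -> None:
    """Validate one KTO sample."""
    sample = json.loads(line)
    # `label` must be a real JSON boolean (true/false), not the string "true"/"false".
    assert isinstance(sample['label'], bool), 'label must be bool, not str'


# Placeholder file name; point this at your own dataset.
with open('dpo_data.jsonl', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            check_dpo_line(line)
```
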
docs/source_en/LLM/Customization.md

Lines changed: 5 additions & 4 deletions
@@ -124,7 +124,7 @@ Language model (DPO/ORPO/SimPO/CPO)
 {"system": "123", "query": "aaaaa", "response": "bbbbb", "rejected_response": "ccccc", "history": [["query1", "response1"], ["query2", "response2"]]}
 {"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
-(Where system and history are optional.)
+- system and history are optional.
 
 Language model (KTO)
 ```jsonl
@@ -134,19 +134,20 @@ Language model (KTO)
 ```
 Note: `label` needs to be of type bool, not str.
 
-(Where system and history are optional.)
+- system and history are optional.
 
 
 Vision MLLM (DPO/ORPO/SimPO/CPO)
 
-Different models have varying support for the number of images. Please refer to the corresponding best practices document for each model.
 ```jsonl
 {"system": "123", "query": "11111", "response": "22222", "rejected_response": "33333", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 {"system": "123", "query": "aaaaa", "response": "bbbbb", "rejected_response": "ccccc", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 {"system": "123", "query": "AAAAA", "response": "BBBBB", "rejected_response": "CCCCC", "images": ["image_path"], "history": [["query1", "response1"], ["query2", "response2"]]}
 ```
 
-(Where system and history are optional.)
+- different models have varying support for the number of images. Please refer to the corresponding best practices document for each model.
+
+- system and history are optional.
 
 
 **Tool-Calling Agent**

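The English document describes the same formats. As a rough illustration of how one of these jsonl files can be inspected with the `datasets` library (this is not swift's own loading path, and the file path is assumed):

```python
from datasets import load_dataset

# Placeholder path; any jsonl in the preference-pair format shown above works.
ds = load_dataset('json', data_files='dpo_data.jsonl', split='train')

print(ds.column_names)             # e.g. ['system', 'query', 'response', 'rejected_response', 'history']
print(ds[0]['rejected_response'])  # the dispreferred answer of the first sample
```
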
swift/trainers/cpo_trainer.py

Lines changed: 16 additions & 4 deletions
@@ -21,10 +21,17 @@ def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
         self.template = template
         kwargs.pop('gamma', None)
         is_vision = kwargs.pop('is_vision')
-
+        self.keys = []
         super().__init__(*args, **kwargs)
+        self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
+        if self.eval_dataset is not None:
+            self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
         train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder)
-        val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+        if self.eval_dataset is not None:
+            val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+            self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
+        else:
+            self.dataset_info = {'train_dataset': train_ds_info}
         self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
         if test_oom_error:
             self.train_dataset = sort_by_max_length(self.train_dataset, 20000)
@@ -53,6 +60,10 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
         prompt['response'] = None
         prompt_tokens = self.template.encode(prompt)[0]
 
+        # Skip examples that do not contain 'input_ids'
+        if 'input_ids' not in prompt_tokens:
+            return {k: None for k in self.keys}
+
         # resolve conflict in data_collator when labels are None, pop it afterwards
         prompt_tokens['labels'] = prompt_tokens['input_ids']
         # Batching image-related information for paired response using template
@@ -170,7 +181,8 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
                 labels=torch.tensor(batch['chosen_labels']))
 
         batch.update(prompt_tokens)
-
+        if not self.keys:
+            self.keys = (list(batch.keys()))
         return batch
 
     def concatenated_forward(
@@ -216,7 +228,7 @@ def concatenated_forward(
             model_kwargs['output_router_logits'] = True
 
         outputs = model(
-            concatenated_batch['concatenated_input_ids'],
+            input_ids=concatenated_batch['concatenated_input_ids'],
             attention_mask=concatenated_batch['concatenated_attention_mask'],
             use_cache=False,
             **model_kwargs,

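The main change here is defensive handling of un-encodable samples: `tokenize_row` now returns an all-`None` row when the template produces no `input_ids`, and those rows are dropped with `Dataset.filter` right after `super().__init__()`. The sketch below reproduces that map-then-filter pattern on toy data with invented column names; it only illustrates the idea, it is not the trainer code.

```python
from datasets import Dataset

raw = Dataset.from_list([
    {'prompt': 'hello'},
    {'prompt': ''},  # stands in for a sample the template cannot encode
])

KEYS = ['prompt_input_ids', 'prompt_attention_mask']  # invented column names


def tokenize_row(example):
    # Toy stand-in for template.encode(): an empty prompt yields no input_ids,
    # so return a row of Nones with the expected keys instead of raising.
    if not example['prompt']:
        return {k: None for k in KEYS}
    ids = [ord(c) for c in example['prompt']]
    return {'prompt_input_ids': ids, 'prompt_attention_mask': [1] * len(ids)}


tokenized = raw.map(tokenize_row, remove_columns=raw.column_names)
# Drop the placeholder rows, as the trainers now do after super().__init__().
tokenized = tokenized.filter(lambda x: x['prompt_input_ids'] is not None)
print(len(tokenized))  # 1
```
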
swift/trainers/dpo_trainer.py

Lines changed: 17 additions & 5 deletions
@@ -22,11 +22,18 @@ def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False,
         self.template = template
         self.sft_beta = sft_beta
         is_vision = kwargs.pop('is_vision')
-
+        self.keys = []
         super().__init__(*args, **kwargs)
+        self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
+        if self.eval_dataset is not None:
+            self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
         train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder)
-        val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
-        self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
+
+        if self.eval_dataset is not None:
+            val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+            self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
+        else:
+            self.dataset_info = {'train_dataset': train_ds_info}
         if test_oom_error:
             self.train_dataset = sort_by_max_length(self.train_dataset, 20000)
         # performance
@@ -54,6 +61,10 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
         prompt['response'] = None
         prompt_tokens = self.template.encode(prompt)[0]
 
+        # Skip examples that do not contain 'input_ids'
+        if 'input_ids' not in prompt_tokens:
+            return {k: None for k in self.keys}
+
         # resolve conflict in data_collator when labels are None, pop it afterwards
         prompt_tokens['labels'] = prompt_tokens['input_ids']
         # Batching image-related information for paired response using template
@@ -171,7 +182,8 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
                 labels=torch.tensor(batch['chosen_labels']))
 
         batch.update(prompt_tokens)
-
+        if not self.keys:
+            self.keys = (list(batch.keys()))
         return batch
 
     def get_batch_loss_metrics(
@@ -289,7 +301,7 @@ def concatenated_forward(
             model_kwargs['output_router_logits'] = True
 
         outputs = model(
-            concatenated_batch['concatenated_input_ids'],
+            input_ids=concatenated_batch['concatenated_input_ids'],
             attention_mask=concatenated_batch['concatenated_attention_mask'],
             use_cache=False,
             **model_kwargs,

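Besides the eval-set guard, the forward call now passes `input_ids` by keyword instead of positionally. The toy module below (invented for this example) shows why that matters when a wrapped model's `forward` signature does not start with `input_ids`:

```python
import torch
import torch.nn as nn


class ToyMultimodalWrapper(nn.Module):
    # Invented signature: the first positional parameter is NOT input_ids.
    def forward(self, pixel_values=None, input_ids=None, attention_mask=None):
        assert input_ids is not None, 'input_ids must be passed by keyword here'
        return input_ids.float().mean()


model = ToyMultimodalWrapper()
ids = torch.tensor([[1, 2, 3]])

print(model(input_ids=ids))  # works: binds to the intended parameter
# model(ids)                 # would bind ids to pixel_values and fail the assert
```
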
swift/trainers/kto_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -86,7 +86,11 @@ def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
         is_vision = kwargs.pop('is_vision')
         super().__init__(*args, **kwargs)
         train_ds_info = self.stat_dataset(self.train_dataset)
-        val_ds_info = self.stat_dataset(self.eval_dataset)
+        if self.eval_dataset is not None:
+            val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+            self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
+        else:
+            self.dataset_info = {'train_dataset': train_ds_info}
         self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
         if test_oom_error:
             self.train_dataset = sort_by_max_length(self.train_dataset, 20000)

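The KTO trainer receives the same guard: validation statistics are only computed when an eval split actually exists. A condensed sketch of that pattern with a hypothetical `stat_dataset` stand-in:

```python
from typing import Optional


def stat_dataset(ds) -> dict:
    # Hypothetical stand-in: report only the number of samples.
    return {'num_samples': len(ds)}


def build_dataset_info(train_dataset, eval_dataset: Optional[list]) -> dict:
    info = {'train_dataset': stat_dataset(train_dataset)}
    # Only add validation stats when an eval split was provided.
    if eval_dataset is not None:
        info['val_dataset'] = stat_dataset(eval_dataset)
    return info


print(build_dataset_info([1, 2, 3], None))  # {'train_dataset': {'num_samples': 3}}
print(build_dataset_info([1, 2, 3], [4]))   # adds 'val_dataset'
```
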
swift/trainers/orpo_trainer.py

Lines changed: 16 additions & 3 deletions
@@ -20,9 +20,17 @@ class ORPOTrainer(PushToMsHubMixin, SwiftMixin, HFORPOTrainer):
     def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
         self.template = template
         is_vision = kwargs.pop('is_vision')
+        self.keys = []
         super().__init__(*args, **kwargs)
+        self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
+        if self.eval_dataset is not None:
+            self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
         train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder)
-        val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+        if self.eval_dataset is not None:
+            val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
+            self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
+        else:
+            self.dataset_info = {'train_dataset': train_ds_info}
         self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
         if test_oom_error:
             self.train_dataset = sort_by_max_length(self.train_dataset, 20000)
@@ -51,6 +59,10 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
         prompt['response'] = None
         prompt_tokens = self.template.encode(prompt)[0]
 
+        # Skip examples that do not contain 'input_ids'
+        if 'input_ids' not in prompt_tokens:
+            return {k: None for k in self.keys}
+
         # resolve conflict in data_collator when labels are None, pop it afterwards
         prompt_tokens['labels'] = prompt_tokens['input_ids']
         # Batching image-related information for paired response using template
@@ -168,7 +180,8 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
                 labels=torch.tensor(batch['chosen_labels']))
 
         batch.update(prompt_tokens)
-
+        if not self.keys:
+            self.keys = (list(batch.keys()))
         return batch
 
     def concatenated_forward(
@@ -214,7 +227,7 @@ def concatenated_forward(
             model_kwargs['output_router_logits'] = True
 
         outputs = model(
-            concatenated_batch['concatenated_input_ids'],
+            input_ids=concatenated_batch['concatenated_input_ids'],
             attention_mask=concatenated_batch['concatenated_attention_mask'],
             use_cache=False,
             **model_kwargs,

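The `self.keys` bookkeeping shared by the CPO/DPO/ORPO changes records the column names of the first successfully tokenized row, so that later rows that fail to encode can return a dict with the same keys (all `None`) and keep the dataset schema consistent. A self-contained sketch of that idea with invented column names:

```python
class RowTokenizer:
    """Toy version of the keys-caching used in tokenize_row."""

    def __init__(self):
        self.keys = []  # filled from the first row that tokenizes successfully

    def __call__(self, text: str) -> dict:
        if not text:  # stand-in for "template.encode produced no input_ids"
            return {k: None for k in self.keys}
        row = {
            'prompt_input_ids': [ord(c) for c in text],  # invented columns
            'prompt_attention_mask': [1] * len(text),
        }
        if not self.keys:
            self.keys = list(row.keys())
        return row


tok = RowTokenizer()
print(tok('hi'))  # real row; also caches the key list
print(tok(''))    # placeholder row with the same keys, all None
```

As in the trainers, this relies on at least one row encoding successfully before the first failure is seen.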