Commit e457716

Fix glm4v rlhf (#1745)
* update
* update
* update
* lint
1 parent 98bf327 commit e457716

4 files changed: +34 -12 lines changed

swift/trainers/cpo_trainer.py

Lines changed: 12 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -204,7 +204,8 @@ def concatenated_forward(
         )
 
         if self.is_vision_model:
-            concatenated_batch = self.concatenated_vision_inputs(batch, concatenated_batch)
+            concatenated_batch = self.concatenated_vision_inputs(
+                batch, concatenated_batch, device=self.accelerator.device)
 
         len_chosen = batch['chosen_labels'].shape[0]
 
@@ -216,8 +217,9 @@ def concatenated_forward(
         } if self.is_encoder_decoder else {})
 
         if self.is_vision_model:
+            # Here, we restore the _data, processing image information within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
+            if self._data_keys:
                 _data = [dict() for _ in range(batch_size)]
                 for k in self._data_keys:
                     if k == 'input_ids':
@@ -231,6 +233,9 @@ def concatenated_forward(
                         _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
                 model_kwargs['_data'] = _data
 
+        if 'images' in concatenated_batch:
+            model_kwargs['images'] = concatenated_batch['images']
+
         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
 
@@ -292,6 +297,7 @@ def cross_entropy_loss(logits, labels):
     def concatenated_vision_inputs(
             batch: Dict[str, Union[List, torch.LongTensor]],
             concatenated_batch: Dict[str, torch.LongTensor],
+            device: Optional[torch.device] = None,
     ) -> Dict[str, torch.LongTensor]:
         if 'prompt_pixel_values' in batch:
             pixel_values = [values for values in batch['prompt_pixel_values']]
@@ -308,6 +314,9 @@ def concatenated_vision_inputs(
         if 'prompt_image_sizes' in batch:
            concatenated_batch['image_sizes'] = batch['prompt_image_sizes']
 
+        if 'prompt_images' in batch:
+            # images not in _data, we manually execute data collector here
+            concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
         return concatenated_batch
 
     @staticmethod
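
Note on the change above: the new `prompt_images` handling duplicates the prompt's image tensor for the chosen and rejected halves of the concatenated batch and moves it to the training device, mirroring how the other vision keys are repeated. A minimal sketch of that tensor manipulation, with illustrative shapes that are assumptions rather than values from the commit:

import torch

# Hypothetical collated batch: 2 prompts, each with 1 image of shape (3, 224, 224),
# padded along dim 1 by the patched data collator.
prompt_images = torch.randn(2, 1, 3, 224, 224)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# squeeze(1) drops the per-prompt image axis -> (2, 3, 224, 224);
# repeat(2, 1, 1, 1) stacks copies for the chosen and rejected halves -> (4, 3, 224, 224);
# .to(device=...) places the tensor on the same device as the model.
images = prompt_images.squeeze(1).repeat(2, 1, 1, 1).to(device=device)
assert images.shape == (4, 3, 224, 224)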

swift/trainers/dpo_trainer.py

Lines changed: 9 additions & 4 deletions
@@ -294,7 +294,7 @@ def concatenated_forward(
         if self.is_vision_model:
             # Here, we restore the _data, processing image information within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
+            if self._data_keys:
                 _data = [dict() for _ in range(batch_size)]
                 for k in self._data_keys:
                     if k == 'input_ids':
@@ -306,7 +306,10 @@ def concatenated_forward(
                         _data = [{**d, k: concatenated_batch[k][i // 2].to(model_dtype)} for i, d in enumerate(_data)]
                     else:
                         _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
-            model_kwargs['_data'] = _data
+                model_kwargs['_data'] = _data
+
+        if 'images' in concatenated_batch:
+            model_kwargs['images'] = concatenated_batch['images']
 
         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
@@ -427,9 +430,8 @@ def concatenated_inputs(
             batch['prompt_attention_mask'].repeat(2, 1).to(device=device))
 
         # patch here
-        # leave data collector in hook
-
         if is_vision_model:
+            # for keys appear in _data, we leave data collector in hook
             if 'prompt_pixel_values' in batch:
                 pixel_values = [values for values in batch['prompt_pixel_values']]
                 concatenated_batch['pixel_values'] = pixel_values
@@ -445,6 +447,9 @@ def concatenated_inputs(
             if 'prompt_image_sizes' in batch:
                 concatenated_batch['image_sizes'] = batch['prompt_image_sizes']
 
+            if 'prompt_images' in batch:
+                # images not in _data, we manually execute data collector here
+                concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
         return concatenated_batch
 
     @staticmethod
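
Two details of the dpo change are easy to miss: the guard switches from `is not None` to plain truthiness, so an empty `_data_keys` container no longer attaches a `_data` made of empty dicts, and models such as glm4v instead receive their images directly through `model_kwargs['images']`. A small self-contained sketch of that guard logic, using hypothetical names rather than the trainer's real attributes:

from typing import Dict, List, Optional

import torch


def build_model_kwargs(data_keys: Optional[Dict[str, str]], concatenated_batch: Dict[str, torch.Tensor],
                       batch_size: int) -> Dict[str, object]:
    # Sketch: only reconstruct _data when there are keys to restore in the forward hook.
    model_kwargs: Dict[str, object] = {}
    # `if data_keys is not None` would also enter this branch for an empty dict,
    # attaching a useless list of empty dicts; truthiness skips it entirely.
    if data_keys:
        _data: List[dict] = [dict() for _ in range(batch_size)]
        for k in data_keys:
            # chosen/rejected pairs share the prompt's value, hence i // 2
            _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
        model_kwargs['_data'] = _data
    # Models like glm4v take their images as a plain forward argument instead.
    if 'images' in concatenated_batch:
        model_kwargs['images'] = concatenated_batch['images']
    return model_kwargs


kwargs = build_model_kwargs({}, {'images': torch.zeros(4, 3, 224, 224)}, batch_size=4)
assert '_data' not in kwargs and 'images' in kwargs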

swift/trainers/orpo_trainer.py

Lines changed: 10 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -215,8 +215,9 @@ def concatenated_forward(
         } if self.is_encoder_decoder else {})
 
         if self.is_vision_model:
+            # Here, we restore the _data, processing image information within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
+            if self._data_keys:
                 _data = [dict() for _ in range(batch_size)]
                 for k in self._data_keys:
                     if k == 'input_ids':
@@ -230,6 +231,9 @@ def concatenated_forward(
                         _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
                 model_kwargs['_data'] = _data
 
+        if 'images' in concatenated_batch:
+            model_kwargs['images'] = concatenated_batch['images']
+
         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
 
@@ -293,6 +297,7 @@ def cross_entropy_loss(logits, labels):
     def concatenated_vision_inputs(
             batch: Dict[str, Union[List, torch.LongTensor]],
             concatenated_batch: Dict[str, torch.LongTensor],
+            device: Optional[torch.device] = None,
     ) -> Dict[str, torch.LongTensor]:
         if 'prompt_pixel_values' in batch:
             pixel_values = [values for values in batch['prompt_pixel_values']]
@@ -309,6 +314,9 @@ def concatenated_vision_inputs(
         if 'prompt_image_sizes' in batch:
             concatenated_batch['image_sizes'] = batch['prompt_image_sizes']
 
+        if 'prompt_images' in batch:
+            # images not in _data, we manually execute data collector here
+            concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
         return concatenated_batch
 
     @staticmethod
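
All three trainers lean on the convention spelled out in the added comment: per-sample multimodal fields are packed into `model_kwargs['_data']` and only unpacked inside a hook on the model's forward, so padding and concatenation stay in the trainer. A rough sketch of that general pattern using a plain PyTorch forward pre-hook; it illustrates the mechanism only and is not swift's actual hook implementation:

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, inputs_embeds: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.linear(inputs_embeds)


def restore_data_hook(module, args, kwargs):
    # Pop the per-sample payload before the real forward runs; a real hook would
    # turn the stored image fields into embeddings and merge them into the inputs here.
    _data = kwargs.pop('_data', None)
    if _data is not None:
        print(f'restored {len(_data)} per-sample dicts')
    return args, kwargs


model = TinyModel()
model.register_forward_pre_hook(restore_data_hook, with_kwargs=True)
out = model(torch.randn(2, 4), _data=[{'images': None}, {'images': None}])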

swift/trainers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -151,7 +151,7 @@ def patch_datacollator():
     def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
         padded_batch = {}
         for k in features[0].keys():
-            if k.endswith(('_input_ids', '_attention_mask', '_labels', '_pixel_values')):
+            if k.endswith(('_input_ids', '_attention_mask', '_labels', '_pixel_values', '_images')):
                 if self.is_encoder_decoder:
                     to_pad = [torch.LongTensor(ex[k]) for ex in features]
 
@@ -187,7 +187,7 @@ def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                        padding_value = self.label_pad_token_id
                    elif k.endswith('_attention_mask'):
                        padding_value = 0
-                   elif k.endswith('_pixel_values'):
+                   elif k.endswith(('_pixel_values', '_images')):
                        padding_value = 0
                    else:
                        raise ValueError(f"Unexpected key in batch '{k}'")
@@ -199,7 +199,7 @@ def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                        padding_side = 'right'
 
                    # Set the dtype
-                   if k.endswith('_pixel_values'):
+                   if k.endswith(('_pixel_values', '_images')):
                        dtype = torch.float32  # will be downcasted if necessary by the Trainer
                    else:
                        dtype = torch.int64
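
The utils.py patch extends the collator's special-casing of `_pixel_values` keys to keys ending in `_images`: those entries are zero-padded and kept as float32 instead of being treated as int64 token sequences. A minimal illustration of that padding path, simplified from the patched collator (it ignores labels, attention masks and the encoder-decoder branch):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two examples whose 'prompt_images' entries contain different numbers of images.
features = [
    {'prompt_images': torch.randn(1, 3, 224, 224)},
    {'prompt_images': torch.randn(2, 3, 224, 224)},
]

k = 'prompt_images'
if k.endswith(('_pixel_values', '_images')):
    padding_value = 0        # image-like keys are zero-padded
    dtype = torch.float32    # and stay floating point
else:
    padding_value = 0
    dtype = torch.int64

to_pad = [ex[k].to(dtype) for ex in features]
padded = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
print(padded.shape)  # torch.Size([2, 2, 3, 224, 224])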
