|
8 | 8 | import torch.nn.functional as F |
9 | 9 | from torch import Tensor |
10 | 10 | from torch.nn.utils.rnn import pad_sequence |
11 | | -from transformers import PreTrainedTokenizerBase, StoppingCriteria |
| 11 | +from transformers import (DataCollatorForSeq2Seq, PreTrainedTokenizerBase, |
| 12 | + StoppingCriteria) |
12 | 13 |
|
13 | 14 | from swift.llm.agent.utils import calculate_loss_scale |
14 | 15 |
|
@@ -186,6 +187,10 @@ def _init_template(self, |
186 | 187 | self.truncation_strategy = truncation_strategy |
187 | 188 | self.model = kwargs.get('model', None) |
188 | 189 | self.use_loss_scale = kwargs.get('use_loss_scale', False) |
| 190 | + self._data_collator = DataCollatorForSeq2Seq( |
| 191 | + tokenizer=self.tokenizer, |
| 192 | + label_pad_token_id=self.tokenizer.pad_token_id, |
| 193 | + ) |
189 | 194 | for key in [ |
190 | 195 | 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system' |
191 | 196 | ]: |
@@ -386,55 +391,28 @@ def concat_tokenizer_kwargs( |
386 | 391 | assert len(old_tokenizer_kwargs) == 0 |
387 | 392 | return curr_tokenizer_kwargs |
388 | 393 |
|
def data_collator(
        self,
        batch: List[Dict[str, Any]],
        pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
    """Collate encoded examples into one padded batch.

    Padding of the tokenizer-level fields is delegated to
    `self._data_collator`; the optional per-token `loss_scale` field is
    padded here because it must be withheld from the delegated collator.

    Args:
        batch(`List[Dict[str, Any]]`): The input data in batch. The example
            dicts are not modified.
        pad_to_multiple_of(`int`, optional): If set, this value is forwarded
            to the underlying collator as its `pad_to_multiple_of` setting.
            If none, that setting is cleared.

    Returns:
        `Dict[str, Any]`: The output of the underlying collator, plus a
        `loss_scale` tensor whose second dimension equals that of
        `input_ids` when the examples carry a `loss_scale` field.
    """
    collator = self._data_collator
    collator.pad_to_multiple_of = pad_to_multiple_of
    # Padded label positions must carry the loss ignore index (-100), as
    # the previous hand-written collator did. Padding labels with the pad
    # token id would let padding positions contribute to the loss.
    collator.label_pad_token_id = -100
    if pad_to_multiple_of:
        # NOTE(review): this persistently changes the tokenizer's padding
        # side; kept from the original implementation.
        self.tokenizer.padding_side = 'right'

    has_loss_scale = 'loss_scale' in batch[0]
    # Hand shallow copies without `loss_scale` to the collator: the ragged
    # field must not be tensorized by it, and the caller's dicts (which may
    # be cached dataset rows) must stay intact for later reuse.
    features = [{key: value
                 for key, value in example.items() if key != 'loss_scale'}
                for example in batch]
    res = collator(features, return_tensors='pt')
    if not has_loss_scale:
        return res

    padded_len = res['input_ids'].shape[1]
    # Pad on the same side the tokenizer uses so that every loss-scale entry
    # stays aligned with its token.
    pad_left = getattr(self.tokenizer, 'padding_side', 'right') == 'left'
    rows = []
    for example in batch:
        scale = torch.tensor(example['loss_scale'])
        gap = padded_len - scale.shape[-1]
        pad_spec = (gap, 0) if pad_left else (0, gap)
        rows.append(F.pad(scale, pad_spec, 'constant', 0.))
    res['loss_scale'] = torch.stack(rows)
    return res
440 | 418 |
|
@@ -601,10 +579,11 @@ def encode( |
601 | 579 | inputs['images'] = image_tensor.to(model.dtype) |
602 | 580 | return inputs, {} |
603 | 581 |
|
def data_collator(
        self,
        batch: List[Dict[str, Any]],
        pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
    """Collate a batch and attach the concatenated image tensors.

    Args:
        batch(`List[Dict[str, Any]]`): The input data in batch; every
            example is read for its `images` entry.
        pad_to_multiple_of(`int`, optional): Forwarded unchanged to the
            parent collator.

    Returns:
        `Dict[str, Any]`: The parent collator's result with an added
        `images` entry.
    """
    collated = super().data_collator(batch, pad_to_multiple_of)
    per_example_images = [example['images'] for example in batch]
    # torch.concat joins along dim 0 by default.
    collated['images'] = torch.concat(per_example_images)
    return collated
610 | 589 |
|
@@ -908,10 +887,11 @@ def encode( |
908 | 887 | inputs['image_sizes'] = image_sizes |
909 | 888 | return inputs, {} |
910 | 889 |
|
def data_collator(
        self,
        batch: List[Dict[str, Any]],
        pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]:
    """Collate a batch and attach image tensors and their sizes.

    Args:
        batch(`List[Dict[str, Any]]`): The input data in batch; every
            example is read for its `images` and `image_sizes` entries.
        pad_to_multiple_of(`int`, optional): Forwarded unchanged to the
            parent collator.

    Returns:
        `Dict[str, Any]`: The parent collator's result with an added
        `images` entry (per-example tensors concatenated) and an
        `image_sizes` entry (per-example lists flattened into one list,
        in batch order).
    """
    res = super().data_collator(batch, pad_to_multiple_of)
    res['images'] = torch.concat([b['images'] for b in batch])
    # Flatten in one pass. `sum(lists, start=[])` re-copies the accumulated
    # list for every example and needs Python 3.8+ for the `start` keyword.
    res['image_sizes'] = [
        size for b in batch for size in b['image_sizes']
    ]
    return res
@@ -1093,10 +1073,11 @@ def encode( |
1093 | 1073 | len(inputs['input_ids']) - len(token_type_ids)) |
1094 | 1074 | return inputs, {} |
1095 | 1075 |
|
1096 | | - def data_collator(self, |
1097 | | - batch: List[Dict[str, Any]], |
1098 | | - padding_to: Optional[int] = None) -> Dict[str, Any]: |
1099 | | - res = super().data_collator(batch, padding_to) |
| 1076 | + def data_collator( |
| 1077 | + self, |
| 1078 | + batch: List[Dict[str, Any]], |
| 1079 | + pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]: |
| 1080 | + res = super().data_collator(batch, pad_to_multiple_of) |
1100 | 1081 | is_cogagent = 'cross_images' in batch[0] |
1101 | 1082 | keys = ['images', 'cross_images'] if is_cogagent else ['images'] |
1102 | 1083 | for key in keys: |
|
0 commit comments