8 | 8 | import torch.nn.functional as F |
9 | 9 | from torch import Tensor |
10 | 10 | from torch.nn.utils.rnn import pad_sequence |
11 | | -from transformers import (DataCollatorForSeq2Seq, PreTrainedTokenizerBase, |
12 | | - StoppingCriteria) |
| 11 | +from transformers import PreTrainedTokenizerBase, StoppingCriteria |
13 | 12 |
|
14 | 13 | from swift.llm.agent.utils import calculate_loss_scale |
15 | 14 |
|
@@ -187,10 +186,6 @@ def _init_template(self, |
187 | 186 | self.truncation_strategy = truncation_strategy |
188 | 187 | self.model = kwargs.get('model', None) |
189 | 188 | self.use_loss_scale = kwargs.get('use_loss_scale', False) |
190 | | - self._data_collator = DataCollatorForSeq2Seq( |
191 | | - tokenizer=self.tokenizer, |
192 | | - label_pad_token_id=self.tokenizer.pad_token_id, |
193 | | - ) |
194 | 189 | for key in [ |
195 | 190 | 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system' |
196 | 191 | ]: |
@@ -391,28 +386,55 @@ def concat_tokenizer_kwargs( |
391 | 386 | assert len(old_tokenizer_kwargs) == 0 |
392 | 387 | return curr_tokenizer_kwargs |
393 | 388 |
|
394 | | - def data_collator( |
395 | | - self, |
396 | | - batch: List[Dict[str, Any]], |
397 | | - pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]: |
| 389 | + def data_collator(self, |
| 390 | + batch: List[Dict[str, Any]], |
| 391 | + padding_to: Optional[int] = None) -> Dict[str, Any]: |
398 | 392 | """ |
399 | 393 | Args: |
400 | 394 | batch(`List[Dict[str, Any]]`): The input data in batch |
401 | | - pad_to_multiple_of(`int`, optional): Whether padding to the multiple of an integer value. |
| 395 | + padding_to(`int`, optional): If set, pad the batch to this fixed length;
| 396 | + if None, the batch is padded to the longest sequence it contains.
402 | 397 | """ |
403 | | - self._data_collator.pad_to_multiple_of = pad_to_multiple_of |
404 | | - if pad_to_multiple_of: |
405 | | - self.tokenizer.padding_side = 'right' |
406 | | - loss_scale = [torch.tensor(b.pop('loss_scale')) |
| 398 | + tokenizer = self.tokenizer |
| 399 | + assert tokenizer.pad_token_id is not None |
| 400 | + input_ids = [torch.tensor(b['input_ids']) for b in batch] |
| 401 | + labels = [torch.tensor(b['labels']) for b in batch] |
| 402 | + loss_scale = [torch.tensor(b['loss_scale']) |
407 | 403 | for b in batch] if 'loss_scale' in batch[0] else None |
408 | | - res = self._data_collator(batch, return_tensors='pt') |
409 | | - padding_to = res['input_ids'].shape[1] |
| 404 | + attention_mask = [ |
| 405 | + torch.ones(len(input_ids[i]), dtype=torch.int64) |
| 406 | + for i in range(len(input_ids)) |
| 407 | + ] |
| 408 | + |
| 409 | + if padding_to is not None: |
| 410 | + padding_len = padding_to - input_ids[0].shape[-1] |
| 411 | + if padding_len > 0: |
| 412 | + input_ids[0] = F.pad(input_ids[0], (0, padding_len), |
| 413 | + 'constant', tokenizer.pad_token_id) |
| 414 | + attention_mask[0] = F.pad(attention_mask[0], (0, padding_len), |
| 415 | + 'constant', 0) |
| 416 | + labels[0] = F.pad(labels[0], (0, padding_len), 'constant', |
| 417 | + -100) |
410 | 418 | if loss_scale: |
411 | 419 | loss_scale[0] = F.pad(loss_scale[0], |
412 | | - (0, padding_to - loss_scale[0].shape[-1]), |
| 420 | + (0, padding_to - labels[0].shape[-1]), |
413 | 421 | 'constant', 0.) |
| 422 | + |
| 423 | + input_ids = pad_sequence( |
| 424 | + input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) |
| 425 | + attention_mask = pad_sequence( |
| 426 | + attention_mask, batch_first=True, padding_value=0) |
| 427 | + if loss_scale: |
414 | 428 | loss_scale = pad_sequence( |
415 | 429 | loss_scale, batch_first=True, padding_value=0.) |
| 430 | + labels = pad_sequence(labels, batch_first=True, padding_value=-100) |
| 431 | + |
| 432 | + res = { |
| 433 | + 'input_ids': input_ids, |
| 434 | + 'attention_mask': attention_mask, |
| 435 | + 'labels': labels, |
| 436 | + } |
| 437 | + if loss_scale is not None: |
416 | 438 | res['loss_scale'] = loss_scale |
417 | 439 | return res |
418 | 440 |
|
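For context, a minimal usage sketch of the rewritten collator (the batch dicts, token ids, and the `template` object below are illustrative assumptions, not part of this PR). Note that `padding_to` only pads the first sequence with `F.pad`; the following `pad_sequence` call then brings every other sequence up to that same length.

    # Hypothetical sketch; assumes `template` is an initialized Template whose
    # tokenizer has a pad_token_id set, and uses made-up token ids.
    batch = [
        {'input_ids': [1, 2, 3, 4], 'labels': [-100, -100, 3, 4]},
        {'input_ids': [1, 2], 'labels': [-100, 2]},
    ]
    res = template.data_collator(batch)                      # pad to the longest: shape (2, 4)
    res_fixed = template.data_collator(batch, padding_to=8)  # pad to a fixed length: shape (2, 8)
    # Label positions added by padding are -100, so they are ignored by the loss;
    # the attention_mask is 0 over the padded positions.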
@@ -579,11 +601,10 @@ def encode( |
579 | 601 | inputs['images'] = image_tensor.to(model.dtype) |
580 | 602 | return inputs, {} |
581 | 603 |
|
582 | | - def data_collator( |
583 | | - self, |
584 | | - batch: List[Dict[str, Any]], |
585 | | - pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]: |
586 | | - res = super().data_collator(batch, pad_to_multiple_of) |
| 604 | + def data_collator(self, |
| 605 | + batch: List[Dict[str, Any]], |
| 606 | + padding_to: Optional[int] = None) -> Dict[str, Any]: |
| 607 | + res = super().data_collator(batch, padding_to) |
587 | 608 | res['images'] = torch.concat([b['images'] for b in batch]) |
588 | 609 | return res |
589 | 610 |
|
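The multimodal subclasses below keep the same pattern: delegate text padding to `super().data_collator(batch, padding_to)`, then merge the per-sample vision tensors themselves. A small sketch of the `torch.concat` step, with made-up image shapes (the real shapes depend on the model's image processor):

    # Illustrative only: if each sample carries a (1, 3, 448, 448) image tensor,
    # concatenating along dim 0 yields a (batch_size, 3, 448, 448) batch tensor.
    import torch
    images = [torch.zeros(1, 3, 448, 448) for _ in range(2)]
    merged = torch.concat(images)  # shape: (2, 3, 448, 448)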
@@ -887,11 +908,10 @@ def encode( |
887 | 908 | inputs['image_sizes'] = image_sizes |
888 | 909 | return inputs, {} |
889 | 910 |
|
890 | | - def data_collator( |
891 | | - self, |
892 | | - batch: List[Dict[str, Any]], |
893 | | - pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]: |
894 | | - res = super().data_collator(batch, pad_to_multiple_of) |
| 911 | + def data_collator(self, |
| 912 | + batch: List[Dict[str, Any]], |
| 913 | + padding_to: Optional[int] = None) -> Dict[str, Any]: |
| 914 | + res = super().data_collator(batch, padding_to) |
895 | 915 | res['images'] = torch.concat([b['images'] for b in batch]) |
896 | 916 | res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[]) |
897 | 917 | return res |
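Here each sample's `image_sizes` is a list of size tuples, and `sum(..., start=[])` flattens the per-sample lists into one batch-level list (the keyword form of `start` needs Python 3.8+). A tiny illustration with made-up sizes:

    # Flatten per-sample lists into one list; the sizes are invented for the example.
    image_sizes = sum([[(336, 336)], [(672, 336), (336, 336)]], start=[])
    # -> [(336, 336), (672, 336), (336, 336)]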
@@ -1073,11 +1093,10 @@ def encode( |
1073 | 1093 | len(inputs['input_ids']) - len(token_type_ids)) |
1074 | 1094 | return inputs, {} |
1075 | 1095 |
|
1076 | | - def data_collator( |
1077 | | - self, |
1078 | | - batch: List[Dict[str, Any]], |
1079 | | - pad_to_multiple_of: Optional[int] = None) -> Dict[str, Any]: |
1080 | | - res = super().data_collator(batch, pad_to_multiple_of) |
| 1096 | + def data_collator(self, |
| 1097 | + batch: List[Dict[str, Any]], |
| 1098 | + padding_to: Optional[int] = None) -> Dict[str, Any]: |
| 1099 | + res = super().data_collator(batch, padding_to) |
1081 | 1100 | is_cogagent = 'cross_images' in batch[0] |
1082 | 1101 | keys = ['images', 'cross_images'] if is_cogagent else ['images'] |
1083 | 1102 | for key in keys: |
|