|
14 | 14 | from transformers.modeling_outputs import (
15 | 15 |     BaseModelOutputWithPastAndCrossAttentions,
16 | 16 |     CausalLMOutputWithCrossAttentions,
   | 17 | +    CausalLMOutputWithPast,
17 | 18 |     QuestionAnsweringModelOutput,
18 | 19 |     SequenceClassifierOutputWithPast,
19 | 20 |     TokenClassifierOutput,
|
|
31 | 32 | from colossalai.pipeline.stage_manager import PipelineStageManager
32 | 33 | from colossalai.shardformer.shard import ShardConfig
33 | 34 |
   | 35 | +from ..layer import cross_entropy_1d
   | 36 | +
34 | 37 |
35 | 38 | def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
36 | 39 |     def build_falcon_alibi_tensor(
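For context (not part of the diff): the relative import `from ..layer import cross_entropy_1d` resolves to `colossalai.shardformer.layer`. `cross_entropy_1d` computes the causal-LM loss directly on the vocabulary shard held by each tensor-parallel rank, so the output of a column-parallel `lm_head` never has to be gathered; the hunks below call it with the flattened shifted logits and labels, the tensor-parallel process group, the full vocabulary size, and the model dtype.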
|
@@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward(
437 | 440 |         loss = None
438 | 441 |         if labels is not None:
439 | 442 |             # Shift so that tokens < n predict n
    | 443 | +            labels = labels.to(lm_logits.device)
440 | 444 |             shift_logits = lm_logits[..., :-1, :].contiguous()
441 | 445 |             shift_labels = labels[..., 1:].contiguous()
442 | 446 |             batch_size, seq_length, vocab_size = shift_logits.shape
443 | 447 |             # Flatten the tokens
444 | 448 |             loss_fct = CrossEntropyLoss()
445 |     | -            loss = loss_fct(
446 |     | -                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
447 |     | -            )
    | 449 | +            if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
    | 450 | +                new_vocab_size = shift_logits.shape[-1]
    | 451 | +                shift_logits = shift_logits.view(-1, new_vocab_size)
    | 452 | +                shift_labels = shift_labels.view(-1)
    | 453 | +                loss = cross_entropy_1d(
    | 454 | +                    shift_logits,
    | 455 | +                    shift_labels,
    | 456 | +                    process_group=shard_config.tensor_parallel_process_group,
    | 457 | +                    vocab_size=self.lm_head.out_features,
    | 458 | +                    dtype=self.transformer.dtype,
    | 459 | +                )
    | 460 | +            else:
    | 461 | +                loss = loss_fct(
    | 462 | +                    shift_logits.view(batch_size * seq_length, vocab_size),
    | 463 | +                    shift_labels.view(batch_size * seq_length),
    | 464 | +                )
448 | 465 |
449 | 466 |         if not return_dict:
450 | 467 |             output = (lm_logits,) + transformer_outputs[1:]
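For context on the new branch: with parallel_output enabled the lm_head is sharded along the vocabulary dimension, so each rank's lm_logits only covers a slice of the vocabulary and the stock CrossEntropyLoss cannot be applied locally. The snippet below is a minimal sketch of the idea behind cross_entropy_1d, not ColossalAI's implementation; the name vocab_parallel_cross_entropy and its exact arguments are illustrative, and it assumes every rank holds a contiguous, equal-sized vocab shard.

# Illustrative sketch only -- NOT ColossalAI's cross_entropy_1d. Assumes each
# tensor-parallel rank holds a contiguous, equally sized slice of the vocabulary
# (the usual column-parallel lm_head layout).
import torch
import torch.distributed as dist


def vocab_parallel_cross_entropy(logits_shard, labels, process_group=None, ignore_index=-100):
    # logits_shard: (num_tokens, vocab_size // tp_size) local slice of the logits
    # labels:       (num_tokens,) global token ids, with ignore_index marking padded positions
    shard_size = logits_shard.size(-1)
    vocab_start = dist.get_rank(process_group) * shard_size
    vocab_end = vocab_start + shard_size

    # log-sum-exp over the *full* vocabulary, assembled from per-shard partial results
    local_max = logits_shard.max(dim=-1).values
    global_max = local_max.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=process_group)
    sum_exp = torch.exp(logits_shard - global_max.unsqueeze(-1)).sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)
    log_z = torch.log(sum_exp) + global_max

    # pick out the target logit; ranks that do not own the target contribute zero
    in_shard = (labels >= vocab_start) & (labels < vocab_end)
    local_labels = (labels - vocab_start).clamp(min=0, max=shard_size - 1)
    target_logit = logits_shard.gather(-1, local_labels.unsqueeze(-1)).squeeze(-1)
    target_logit = torch.where(in_shard, target_logit, torch.zeros_like(target_logit))
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # mean negative log-likelihood over the non-ignored tokens
    valid = labels != ignore_index
    return (log_z - target_logit)[valid].mean()

The real cross_entropy_1d also takes the vocab_size and dtype arguments seen in the call above and handles the backward pass explicitly; this sketch only illustrates the forward math.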
|
@@ -747,3 +764,79 @@ def falcon_for_question_answering_forward(
747 | 764 |     else:
748 | 765 |         hidden_states = outputs.get("hidden_states")
749 | 766 |         return {"hidden_states": hidden_states}
    | 767 | +
    | 768 | +
    | 769 | +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    | 770 | +    from transformers import FalconForCausalLM
    | 771 | +
    | 772 | +    def forward(
    | 773 | +        self: FalconForCausalLM,
    | 774 | +        input_ids: Optional[torch.LongTensor] = None,
    | 775 | +        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    | 776 | +        attention_mask: Optional[torch.Tensor] = None,
    | 777 | +        head_mask: Optional[torch.Tensor] = None,
    | 778 | +        inputs_embeds: Optional[torch.Tensor] = None,
    | 779 | +        labels: Optional[torch.Tensor] = None,
    | 780 | +        use_cache: Optional[bool] = None,
    | 781 | +        output_attentions: Optional[bool] = None,
    | 782 | +        output_hidden_states: Optional[bool] = None,
    | 783 | +        return_dict: Optional[bool] = None,
    | 784 | +    ) -> Union[Tuple, CausalLMOutputWithPast]:
    | 785 | +        r"""
    | 786 | +        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
    | 787 | +            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
    | 788 | +            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
    | 789 | +            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
    | 790 | +        """
    | 791 | +        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    | 792 | +        output_hidden_states = (
    | 793 | +            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    | 794 | +        )
    | 795 | +        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    | 796 | +
    | 797 | +        transformer_outputs = self.transformer(
    | 798 | +            input_ids,
    | 799 | +            past_key_values=past_key_values,
    | 800 | +            attention_mask=attention_mask,
    | 801 | +            head_mask=head_mask,
    | 802 | +            inputs_embeds=inputs_embeds,
    | 803 | +            use_cache=use_cache,
    | 804 | +            output_attentions=output_attentions,
    | 805 | +            output_hidden_states=output_hidden_states,
    | 806 | +            return_dict=return_dict,
    | 807 | +        )
    | 808 | +        past_key_values = None
    | 809 | +        hidden_states = transformer_outputs[0]
    | 810 | +        lm_logits = self.lm_head(hidden_states)
    | 811 | +        loss = None
    | 812 | +        if labels is not None:
    | 813 | +            # Shift so that tokens < n predict n
    | 814 | +            labels = labels.to(lm_logits.device)
    | 815 | +            shift_logits = lm_logits[..., :-1, :].contiguous()
    | 816 | +            shift_labels = labels[..., 1:].contiguous()
    | 817 | +            batch_size, seq_length, vocab_size = shift_logits.shape
    | 818 | +            # Flatten the tokens
    | 819 | +            new_vocab_size = shift_logits.shape[-1]
    | 820 | +            shift_logits = shift_logits.view(-1, new_vocab_size)
    | 821 | +            shift_labels = shift_labels.view(-1)
    | 822 | +            loss = cross_entropy_1d(
    | 823 | +                shift_logits,
    | 824 | +                shift_labels,
    | 825 | +                process_group=shard_config.tensor_parallel_process_group,
    | 826 | +                vocab_size=self.lm_head.out_features,
    | 827 | +                dtype=self.transformer.dtype,
    | 828 | +            )
    | 829 | +
    | 830 | +        if not return_dict:
    | 831 | +            output = (lm_logits,) + transformer_outputs[1:]
    | 832 | +            return ((loss,) + output) if loss is not None else output
    | 833 | +
    | 834 | +        return CausalLMOutputWithPast(
    | 835 | +            loss=loss,
    | 836 | +            logits=lm_logits,
    | 837 | +            past_key_values=transformer_outputs.past_key_values,
    | 838 | +            hidden_states=transformer_outputs.hidden_states,
    | 839 | +            attentions=transformer_outputs.attentions,
    | 840 | +        )
    | 841 | +
    | 842 | +    return forward
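How this factory gets registered is outside the hunk above. The sketch below shows one plausible wiring through a shardformer policy: the subclass name is invented, and the import paths and the append_or_create_method_replacement call reflect the general ColossalAI policy pattern rather than anything in this commit, so treat the whole block as an assumption.

# Hypothetical wiring sketch; not part of this commit.
from transformers.models.falcon.modeling_falcon import FalconForCausalLM

from colossalai.shardformer.modeling.falcon import get_lm_forward_with_dist_cross_entropy
from colossalai.shardformer.policies.falcon import FalconForCausalLMPolicy


class FalconForCausalLMWithDistCEPolicy(FalconForCausalLMPolicy):
    """Illustrative subclass; assumes the stock FalconForCausalLMPolicy and its policy API."""

    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism and self.shard_config.parallel_output:
            # Swap FalconForCausalLM.forward for the closure built above, so the LM loss
            # is computed with cross_entropy_1d on each rank's vocabulary shard.
            self.append_or_create_method_replacement(
                description={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                policy=policy,
                target_key=FalconForCausalLM,
            )
        return policy

Gating on enable_tensor_parallelism and parallel_output mirrors the branch added to falcon_for_causal_lm_forward earlier in the diff.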