@@ -16,7 +16,7 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

-from ..layer import ColoAttention
+from ..layer import ColoAttention, cross_entropy_1d

 logger = logging.get_logger(__name__)

@@ -270,11 +270,22 @@ def mistral_for_causal_lm_forward(
         shift_labels = labels[..., 1:].contiguous()
         # Flatten the tokens
         loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        #shift_logits = shift_logits.view(-1, self.config.vocab_size)
         shift_labels = shift_labels.view(-1)
         # Enable model parallelism
         shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+        else:
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            loss = loss_fct(shift_logits, shift_labels)

     if not return_dict:
         output = (logits,) + outputs[1:]
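The hunk above routes the loss through `cross_entropy_1d` whenever tensor parallelism keeps the `lm_head` output sharded (`parallel_output`), so each rank computes the loss directly on its vocabulary slice instead of first all-gathering the full logits, which avoids materializing the complete (tokens x vocab) tensor on every rank. Below is a minimal single-process sketch of that idea; the helper `sharded_cross_entropy` and the simulated four-way split are illustrative assumptions, not ColossalAI's implementation.

```python
# Single-process sketch of vocab-parallel cross entropy: each "shard" stands in
# for the logit slice held by one tensor-parallel rank; the per-shard partial
# results (max, exp-sum, target logit) are what the real kernel all-reduces.
import torch
import torch.nn.functional as F


def sharded_cross_entropy(logit_shards, labels):
    """Cross entropy over logits split along the vocab dim, without concatenating them."""
    # Global max over the vocab dim for numerical stability (all-reduce MAX across ranks).
    global_max = torch.stack([s.max(dim=-1).values for s in logit_shards]).max(dim=0).values
    # Sum of exp(logit - max) over the full vocab (all-reduce SUM across ranks).
    sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in logit_shards)
    # Logit of the target token, taken from whichever shard owns that vocab slice.
    target_logit = torch.zeros_like(global_max)
    offset = 0
    for s in logit_shards:
        local = labels - offset
        mask = (local >= 0) & (local < s.shape[-1])
        target_logit[mask] = s[mask, local[mask]]
        offset += s.shape[-1]
    # -log softmax(target) = log(sum_exp) + max - target_logit
    return (sum_exp.log() + global_max - target_logit).mean()


# Check against the dense reference on random data.
torch.manual_seed(0)
full_logits = torch.randn(8, 32000)      # (tokens, vocab)
labels = torch.randint(0, 32000, (8,))
shards = full_logits.chunk(4, dim=-1)    # simulate 4 tensor-parallel ranks
assert torch.allclose(sharded_cross_entropy(shards, labels),
                      F.cross_entropy(full_logits, labels), atol=1e-5)
```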
@@ -609,3 +620,105 @@ def forward(
         return attn_output, None, past_key_value

     return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+    from transformers import MistralForCausalLM
+
+    def forward(
+        self: MistralForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return forward
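The second hunk adds `get_lm_forward_with_dist_cross_entropy`, a factory that closes over the `ShardConfig` and returns a drop-in replacement for `MistralForCausalLM.forward` that always computes the distributed loss. The diff itself does not show how this replacement is registered; the sketch below follows the method-replacement pattern used by other shardformer causal-LM policies, and the `MistralForCausalLMPolicy`/`MistralPolicy` class names and module path are assumptions, not part of this commit.

```python
# Hypothetical policy wiring (assumed, not shown in this diff): register the
# distributed-cross-entropy forward only when the lm_head output stays sharded.
from transformers import MistralForCausalLM

from colossalai.shardformer.modeling.mistral import get_lm_forward_with_dist_cross_entropy


class MistralForCausalLMPolicy(MistralPolicy):  # assumed base policy name
    def module_policy(self):
        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism and self.shard_config.parallel_output:
            # Swap in the forward that runs cross_entropy_1d over the sharded logits.
            self.append_or_create_method_replacement(
                description={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                policy=policy,
                target_key=MistralForCausalLM,
            )
        return policy
```

With this wiring, the replaced forward is only active when `parallel_output` keeps the logits column-sharded; otherwise the pipeline path above still falls back to the stock `CrossEntropyLoss` branch.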