import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
-)
-from transformers.models.mixtral.modeling_mixtral import (
-    _CONFIG_FOR_DOC,
-    MIXTRAL_INPUTS_DOCSTRING,
-)
-from transformers.modeling_outputs import (
-    MoeCausalLMOutputWithPast,
-    MoeModelOutputWithPast,
-)
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)

from .cross_entropy import (
    element_mul_kernel,
@@ -297,11 +281,6 @@ def forward(self, lin_weight, _input, target, bias=None):
            self.reduction,
        )

-# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models?
-# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
@@ -435,143 +414,4 @@ def lce_forward(
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
-    )
-
-# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward?
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
-# Ignore copy
-def lce_forward_mixtral(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    output_router_logits: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-) -> Union[Tuple, MoeCausalLMOutputWithPast]:
-    r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, MixtralForCausalLM
-
-    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-
-    >>> prompt = "Hey, are you conscious? Can you talk to me?"
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_router_logits = (
-        output_router_logits if output_router_logits is not None else self.config.output_router_logits
-    )
-
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        output_router_logits=output_router_logits,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    # patch change
-    if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    # TODO: unique differing part to mixtral model forward
-    aux_loss = None
-    if output_router_logits:
-        aux_loss = load_balancing_loss_func(
-            outputs.router_logits if return_dict else outputs[-1],
-            self.num_experts,
-            self.num_experts_per_tok,
-            attention_mask,
-        )
-        # TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss?
-        if labels is not None:
-            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        if output_router_logits:
-            output = (aux_loss,) + output
-        return (loss,) + output if loss is not None else output
-
-    return MoeCausalLMOutputWithPast(
-        loss=loss,
-        aux_loss=aux_loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-        router_logits=outputs.router_logits,
    )