from typing import List
from typing import Optional
from typing import Union

import torch

from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> LigerMoeCausalLMOutputWithPast:
| 32 | + r""" |
| 33 | + Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations. |
| 34 | +
|
| 35 | + Args: |
| 36 | + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 37 | + Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers. |
| 38 | + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 39 | + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| 40 | + - 1 for tokens that are **not masked**, |
| 41 | + - 0 for tokens that are **masked**. |
| 42 | + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 43 | + Indices of positions of each input sequence tokens in the position embeddings. |
| 44 | + past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*): |
| 45 | + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up |
| 46 | + sequential decoding. See `past_key_values` input for more details. |
| 47 | + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| 48 | + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
| 49 | + This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
| 50 | + than the model's internal embedding lookup matrix. |
| 51 | + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 52 | + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| 53 | + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| 54 | + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| 55 | + use_cache (`bool`, *optional*): |
| 56 | + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
| 57 | + (see `past_key_values`). |
| 58 | + output_attentions (`bool`, *optional*): |
| 59 | + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| 60 | + tensors for more detail. |
| 61 | + output_hidden_states (`bool`, *optional*): |
| 62 | + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| 63 | + more detail. |
| 64 | + output_router_logits (`bool`, *optional*): |
| 65 | + Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors |
| 66 | + for more detail. |
| 67 | + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| 68 | + Indices depicting the position of the input sequence tokens in the sequence. |
| 69 | + logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0): |
| 70 | + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all |
| 71 | + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that |
| 72 | + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. |
| 73 | + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. |
| 74 | + This is useful when using packed tensor format (single dimension for batch and sequence length). |
| 75 | + skip_logits (`bool`, *optional*): |
| 76 | + Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training |
| 77 | + when labels are provided (to save memory), and `False` during inference. |
| 78 | +
|
| 79 | + Returns: |
        `LigerMoeCausalLMOutputWithPast`: An output object containing:
        - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction), including the auxiliary load balancing loss.
        - aux_loss (`torch.FloatTensor`, *optional*, returned when `output_router_logits=True`):
            Auxiliary load balancing loss for the sparse MoE modules.
        - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            Note: logits are `None` during training when `skip_logits=True` to save memory.
        - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed):
            Cached key and value projection states for faster sequential decoding.
        - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden states of the model at the output of each layer.
        - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax.
        - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Router logits of the MoE layers, useful for computing the auxiliary loss and z_loss.
        - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Token-level prediction accuracy.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GptOssForCausalLM
    >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

    >>> # Apply Liger Kernel patches for optimized performance
    >>> apply_liger_kernel_to_gpt_oss()

    >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Inference: Forward pass returns logits
    >>> outputs = model(**inputs)
    >>> outputs.logits.shape
    torch.Size([1, 12, 201088])

    >>> # Get next token prediction
    >>> next_token_logits = outputs.logits[:, -1, :]
    >>> predicted_token_id = next_token_logits.argmax(dim=-1)

    >>> # Training: Forward pass with labels returns loss
    >>> labels = inputs.input_ids.clone()
    >>> outputs = model(**inputs, labels=labels)
    >>> outputs.loss
    tensor(2.6454)
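
    >>> # Illustrative: restrict logit computation to the last position to save memory at inference
    >>> outputs = model(**inputs, logits_to_keep=1)
    >>> outputs.logits.shape
    torch.Size([1, 1, 201088])

    >>> # Illustrative: skip logit materialization entirely and compute only the loss via the fused kernel
    >>> outputs = model(**inputs, labels=labels, skip_logits=True)
    >>> outputs.logits is None
    True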
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

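    # Pop `shift_labels` (pre-shifted labels, if a caller provides them) from kwargs so it is passed
    # explicitly to the loss helpers below rather than duplicated via **kwargs.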
    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

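    # When training with labels (or pre-shifted labels), default to skipping logit materialization:
    # the fused linear + cross-entropy path below consumes the hidden states and lm_head weight
    # directly, avoiding a full (batch, seq_len, vocab_size) logits tensor.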
    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:  # in inference mode, materialize the logits
        logits = self.lm_head(kept_hidden_states)
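        # With logits materialized, compute the loss through the model's regular loss_function when
        # labels (or pre-shifted labels) are given.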
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.vocab_size,
                **kwargs,
            )

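    # Load-balancing auxiliary loss computed from the per-layer router logits; when `labels` are
    # provided it is added to the main loss, scaled by `router_aux_loss_coef`.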
    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # keep aux_loss on the same device as loss

    return LigerMoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
        token_accuracy=token_accuracy,
    )