"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""

import sys
from copy import deepcopy
from typing import List, Optional, Union

import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
            for that token can save memory, which becomes quite significant for long sequences or large vocabulary sizes.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:
    """

    # pylint: disable=duplicate-code
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
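        # The call above computes the causal LM loss directly from hidden_states and
        # the lm_head weight via Liger's fused linear cross entropy, so the full
        # (batch, seq_len, vocab_size) logits tensor is never materialized in training.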

    else:  # if in inference mode materialize logits
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
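        # slice(-0, None) == slice(0, None), so logits_to_keep=0 keeps every position,
        # while e.g. logits_to_keep=1 keeps only the final position.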
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
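        # load_balancing_loss_func computes the Switch-Transformers-style auxiliary
        # loss that pushes the router to spread tokens evenly across experts; it is
        # scaled by router_aux_loss_coef before being added to the LM loss below.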
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(
                loss.device
            )  # make sure the aux loss is on the same device as the LM loss

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def apply_liger_kernel_to_qwen3_moe(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> None:
    # pylint: disable=duplicate-code
    """
    Apply Liger kernels to replace the original implementations in HuggingFace Qwen3 MoE models.

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
    """

    import transformers.models.qwen3_moe.modeling_qwen3_moe  # noqa: F401 # pylint: disable=unused-import
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]

    if rms_norm:
        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm

    if glu_activation:

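        # Note (based on the transformers v4.51 Qwen3 MoE implementation): the sparse MoE
        # block builds each expert as Qwen3MoeMLP(config, intermediate_size=moe_intermediate_size),
        # while LigerSwiGLUMLP only reads config.intermediate_size, so the wrapper below
        # copies the config and overrides that field before constructing the Liger MLP.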
        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
            # clone config to avoid modifying the original
            config = deepcopy(config)
            if intermediate_size:
                setattr(config, "intermediate_size", intermediate_size)
            return LigerSwiGLUMLP(config, **kwargs)

        modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper

    if layer_norm:
        modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
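        # Note: `modeling_qwen3_moe.nn` is the `torch.nn` module object, so this
        # assignment replaces `nn.LayerNorm` for anything that resolves it through
        # torch.nn at construction time, not only the Qwen3 MoE model.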

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
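

# Illustrative usage sketch (assumes the patch is applied before the model is
# constructed so the RMSNorm/MLP swaps take effect; the checkpoint id below is
# only an example of a Qwen3 MoE model):
#
#   from transformers import AutoModelForCausalLM
#
#   apply_liger_kernel_to_qwen3_moe(
#       fused_linear_cross_entropy=True,
#       rms_norm=True,
#       glu_activation=True,
#   )
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")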