from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts, with the routed experts
    sharded across the expert-parallel (EP) group.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        # setup ep group
        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        # keep only this rank's contiguous block of experts; release the parameters of all others
        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        # tag the remaining expert parameters with their EP process group
        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

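    # Convert a native module in place: swap its class to EpDeepseekV3MoE and wire up the
    # EP / MoE-DP process groups. Dense `DeepseekV3MLP` layers are only materialized, not converted.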
    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        if module.__class__.__name__ != "DeepseekV3MLP":
            module.__class__ = EpDeepseekV3MoE
            module.setup_process_groups(moe_dp_group, ep_group)
        LazyInitContext.materialize(module)
        return module

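    # Routed-MoE forward: the gate selects top-k experts per token, moe_forward dispatches the
    # flattened tokens to the experts held on each EP rank, and the shared experts (if configured)
    # are applied to the original input and added to the routed output.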
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

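    # Dispatch-compute-combine for the routed experts: sort tokens by expert, exchange them across
    # the EP group with an uneven all-to-all, run the locally held experts, then undo both
    # permutations and weight each token's expert outputs by its routing probabilities.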
    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
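        # Expert parallelism: exchange per-expert token counts, then the tokens themselves, so that
        # every rank receives exactly the tokens routed to the experts it holds.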
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            # regroup the received tokens by local expert index
            gathered_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gathered_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gathered_idxs = gathered_idxs.argsort()
            sorted_tokens = gathered_tokens[gathered_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

            # moe-dp related code
            activate_experts = tokens_per_expert_post_gather > 0
            activate_experts = activate_experts.int()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

        # ep related code
        sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

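        # Run each locally held expert on its contiguous slice of the sorted tokens. The DP grad
        # scalers adjust gradient scaling per expert based on how many moe-dp ranks activated it.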
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens should be empty, but got shape {sorted_tokens.shape}"
            outs = sorted_tokens

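        # Combine: undo the EP dispatch by restoring the pre-all-to-all order and sending each
        # token's output back to the rank it came from.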
        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gathered_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

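        # Scatter the outputs back to the original token order and reduce over the top-k experts,
        # weighting by the routing probabilities.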
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out


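# Module-level forward for DeepseekV3Model: embed tokens, build the causal mask, run the decoder
# layers (optionally gradient-checkpointed), apply the final norm, and return either a tuple or a
# BaseModelOutputWithPast together with the updated KV cache.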
def deepseek_v3_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

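    # Normalize a legacy tuple cache into a DynamicCache and work out how many positions are
    # already cached, so position_ids and the causal mask start at the right offset.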
    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

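    # FlashAttention-2 consumes the raw 2-D padding mask (or None when there is no padding);
    # other attention implementations need the expanded 4-D causal mask built here.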
    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

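    # Run the decoder stack; when gradient checkpointing is enabled, every layer except the first
    # is recomputed in the backward pass instead of storing its activations.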
    for i, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and i > 0:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )