import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    linear_with_async_comm,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
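        # each EP rank keeps only a contiguous slice of the experts; the parameters of
        # experts held by other ranks are released below to save memory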
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

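        # tag every expert parameter with its EP process group so MoE-aware components
        # (e.g. the hybrid-parallel plugin/optimizer) can identify and handle them correctly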
        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        LazyInitContext.materialize(module)
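        # dense (non-MoE) DeepseekV3MLP layers are returned unchanged; only MoE blocks are converted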
        if module.__class__.__name__ == "DeepseekV3MLP":
            return module
        module.__class__ = EpDeepseekV3MoE
        module.setup_process_groups(moe_dp_group, ep_group)
        return module
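    # A minimal usage sketch (the surrounding variable names are assumptions; in practice the
    # shardformer policy calls `from_native_module` while sharding the model):
    #
    #   ep_moe = EpDeepseekV3MoE.from_native_module(
    #       native_moe_block,           # the original DeepseekV3 MoE module
    #       moe_dp_group=moe_dp_group,  # data-parallel group over the expert parameters
    #       ep_group=ep_group,          # expert-parallel group
    #   )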

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

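    # Routed-expert path: sort tokens by expert id, exchange them across EP ranks with an
    # uneven all-to-all when EP is enabled, run the locally held experts, then reverse both
    # permutations and combine the expert outputs with the router weights.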
    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
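        # with expert parallelism, first exchange per-expert token counts so every rank knows
        # its all-to-all send/receive split sizes, then exchange the token activations themselves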
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
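            # re-sort the received tokens so that all tokens destined for the same local expert
            # are contiguous, regardless of which source rank they came from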
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

            # moe-dp related code
            activate_experts = tokens_per_expert_post_gather > 0
            activate_experts = activate_experts.int()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)
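            # `activate_experts[i]` now counts how many ranks in the moe-dp group routed tokens
            # to local expert i; it is passed to DPGradScalerIn/Out below to rescale gradients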

            # ep related code
            sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)
        else:
            # moe-dp related code: keep `activate_experts` defined when EP is disabled
            activate_experts = (tokens_per_expert > 0).int()
            dist.all_reduce(activate_experts, group=self.moe_dp_group)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

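        # run each locally held expert on its contiguous chunk of sorted tokens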
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens should be empty, but got shape {sorted_tokens.shape}"
            outs = sorted_tokens

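        # when EP is enabled, undo the local per-expert re-sort and send each rank's results
        # back to their source ranks with the reverse all-to-all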
        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

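        # restore the original token order and combine the top-k expert outputs weighted by the router scores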
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out