
Commit d96db80

[feature] support ep for deepseek v3
1 parent ca0aa23 commit d96db80

File tree

7 files changed: +461 -2 lines changed

colossalai/shardformer/modeling/deepseek_v3.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
    all_to_all_comm,
    gather_forward_split_backward,
    linear_with_async_comm,
    split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        # setup ep group
        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        # free the experts this rank does not hold
        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        LazyInitContext.materialize(module)
        if module.__class__.__name__ == "DeepseekV3MLP":
            return module
        module.__class__ = EpDeepseekV3MoE
        module.setup_process_groups(moe_dp_group, ep_group)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        # count tokens routed to each expert and sort token copies by expert id
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        if self.ep_size > 1:
            # exchange per-expert token counts, then the tokens themselves, across EP ranks
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            # regroup the received tokens by local expert index
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

        # moe-dp related code
        activate_experts = tokens_per_expert_post_gather > 0
        activate_experts = activate_experts.int()
        dist.all_reduce(activate_experts, group=self.moe_dp_group)

        # ep related code
        sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        # run each held expert on its contiguous slice of the sorted tokens
        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens should be empty, but got shape {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            # send expert outputs back to the ranks that own the corresponding tokens
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        # restore the original token order and combine expert outputs with the routing weights
        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out

colossalai/shardformer/policies/auto_policy.py

Lines changed: 7 additions & 0 deletions
@@ -167,6 +167,13 @@ class PolicyLocation:
     "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
         file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
     ),
+    # DeepseekV3
+    "transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
+    ),
+    "transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
+    ),
     # Falcon
     "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
         file_name="falcon", class_name="FalconModelPolicy"

colossalai/shardformer/policies/deepseek_v3.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
from typing import Callable, Dict, List, Union

import torch.nn as nn

from colossalai.shardformer.layer import FusedRMSNorm
from colossalai.shardformer.modeling.deepseek_v3 import EpDeepseekV3MoE
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["DeepseekV3Policy", "DeepseekV3ModelPolicy", "DeepseekV3ForCausalLMPolicy"]


class DeepseekV3Policy(Policy):
    def config_sanity_check(self):
        assert not self.shard_config.enable_tensor_parallelism, "DeepSeekV3 does not support tensor parallelism"
        assert self.shard_config.pipeline_stage_manager is None, "DeepSeekV3 does not support pipeline parallelism"
        assert not self.shard_config.enable_sequence_parallelism, "DeepSeekV3 does not support sequence parallelism"

    def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

        policy = {}

        if self.shard_config.ep_group:
            # expert parallel
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="mlp",
                        target_module=EpDeepseekV3MoE,
                        kwargs={
                            "ep_group": self.shard_config.ep_group,
                            "moe_dp_group": self.shard_config.moe_dp_group,
                        },
                    )
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # TODO: prevent casting to fp32
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key="DeepseekV3DecoderLayer",
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key="DeepseekV3Model",
            )

        return policy

    def postprocess(self):
        return self.model


class DeepseekV3ModelPolicy(DeepseekV3Policy):
    pass


class DeepseekV3ForCausalLMPolicy(DeepseekV3Policy):
    pass
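
The policy only swaps `mlp` for `EpDeepseekV3MoE` when `shard_config.ep_group` is populated. A hedged sketch of how a training script might arrange that through the MoE hybrid plugin; the argument names follow the usual `MoeHybridParallelPlugin` interface and may need adjusting to the installed ColossalAI version:

# Usage sketch (assumptions: standard Booster / MoeHybridParallelPlugin arguments).
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, zero_stage=1, precision="bf16")
booster = Booster(plugin=plugin)
# model, optimizer and dataloader come from the user's script:
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)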

tests/kit/model_zoo/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from .chatglm2 import *
 from .command import *
 from .deepseek import *
+from .deepseek_v3 import *
 from .falcon import *
 from .gpt import *
 from .gptj import *

tests/kit/model_zoo/transformers/deepseek_v3.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig

from ..registry import ModelAttribute, model_zoo

# ===================================
# Register single-sentence DeepSeek-V3
# ===================================


def data_gen():
    # Generated from the following code snippet
    #
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
    # input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
    # tokenized_input = tokenizer([input], return_tensors="pt")
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # for LM the `labels` are the output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data["labels"] = data["input_ids"].clone()
    return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn = lambda x: x[0].mean()
loss_fn_for_lm = lambda x: x.loss


def init_deepseek():

    config = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-V3",
        hidden_size=128,
        intermediate_size=320,
        kv_lora_rank=4,
        moe_intermediate_size=32,
        num_attention_heads=4,
        num_experts_per_tok=4,
        n_group=4,
        num_hidden_layers=3,
        num_key_value_heads=4,
        first_k_dense_replace=1,
        q_lora_rank=8,
        torch_dtype="bfloat16",
        n_routed_experts=16,
        topk_group=2,
        v_head_dim=32,
        qk_nope_head_dim=32,
        qk_rope_head_dim=32,
        trust_remote_code=True,
        vocab_size=2048,
    )

    if hasattr(config, "pad_token_id"):
        config.pad_token_id = config.eos_token_id
    model = transformers.AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    return model


model_zoo.register(
    name="transformers_deepseek_v3",
    model_fn=init_deepseek,
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_lm,
    model_attribute=ModelAttribute(has_control_flow=True),
)
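
A hedged sketch of how a shardformer test could pull this entry back out of the registry; the sub-registry helper and the tuple unpacking follow the pattern used by other model-zoo tests and may differ slightly in the actual test suite:

# Illustrative consumption of the registry entry registered above.
from tests.kit.model_zoo import model_zoo

sub_registry = model_zoo.get_sub_registry("transformers_deepseek_v3")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
    model = model_fn()                      # builds the tiny DeepSeek-V3 config above
    data = data_gen_fn()
    loss = loss_fn(output_transform_fn(model(**data)))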

tests/test_shardformer/test_model/_utils.py

Lines changed: 0 additions & 2 deletions
@@ -223,7 +223,6 @@ def _criterion(outputs, inputs):
     for k, v in data.items():
         unshard_test_data[k] = data[k].clone()

-    sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in shard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
@@ -248,7 +247,6 @@ def _criterion(outputs, inputs):
         sharded_loss = criterion(sharded_output)
         sharded_optimizer.backward(sharded_loss)

-    org_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in unshard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
