
Commit 2b415e5

[shardformer] support ep for deepseek v3 (#6185)
* [feature] support ep for deepseek v3
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix test
* [shardformer] fix deepseek v3 init
* [lazy] fit lora for lazy init
* [example] support npu for deepseek v3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: 17062c8 · Commit: 2b415e5

13 files changed: +612 −22 lines

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 3 additions & 9 deletions
@@ -19,7 +19,6 @@
     HybridParallelPlugin,
     HybridParallelZeroOptimizer,
     get_param_info,
-    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh

@@ -468,18 +467,13 @@ def configure(
                 use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.ep_size > 1:
-                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
-                # but the optimizer is not aware of ep, so we need to update the optimizer
-                reinitialize_optimizer(optimizer, model)
-
             if self.zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,

@@ -489,7 +483,7 @@ def configure(
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,

@@ -507,7 +501,7 @@ def configure(
             optimizer = MoeHybridParallelZeroOptimizer(
                 optimizer,
                 model,
-                use_pipeline=self.enable_pipeline_parallelism,
+                use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                 param_info=param_info,
                 dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
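
Editor's note: instead of rebuilding the optimizer after expert parameters are sharded (the removed reinitialize_optimizer call), configure() now tells the optimizer wrappers to treat the model like a pipeline-sharded one whenever ep_size > 1. For context, a minimal usage sketch of the plugin with expert parallelism enabled, assuming the usual Booster workflow; the model builder and the sizes below are illustrative assumptions, not part of this commit:

# Hedged sketch: boosting a MoE model with expert parallelism enabled.
# build_deepseek_v3_model() is a hypothetical helper; all sizes are illustrative.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # assumes the process was started via torchrun

model = build_deepseek_v3_model()  # hypothetical: returns a DeepseekV3ForCausalLM
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,        # experts are sharded across 2 EP ranks
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)
# configure() above passes use_pipeline=True to the optimizer wrapper whenever ep_size > 1,
# so no separate optimizer re-initialization is needed after EP sharding.
model, optimizer, *_ = booster.boost(model, optimizer)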

colossalai/cluster/process_group_mesh.py

Lines changed: 4 additions & 1 deletion
@@ -64,7 +64,10 @@ def destroy_mesh_process_groups(self):
         system resources.
         """
         for group in self._ranks_to_group.values():
-            dist.destroy_process_group(group)
+            try:
+                dist.destroy_process_group(group)
+            except ValueError:
+                pass

         # Manually clear all process groups to save memory
         gc.collect()
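
This makes group teardown idempotent. A standalone restatement of the pattern, under the assumption (implied by the change above) that torch.distributed reports an already-destroyed or unregistered group with ValueError:

import torch.distributed as dist

def destroy_group_quietly(group) -> None:
    """Destroy a process group, tolerating double-destroy or unregistered groups."""
    try:
        dist.destroy_process_group(group)
    except ValueError:
        # already destroyed, or never registered on this rank; nothing to clean up
        pass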

colossalai/lazy/lazy_init.py

Lines changed: 5 additions & 3 deletions
@@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
     return tensor.data.tolist()


-def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.

     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.

@@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
+    requires_grad = target.requires_grad if requires_grad is None else requires_grad
     cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
     if cls_to_become is Parameter:
         # to fit UninitializedParameter
         delattr(tensor, "_is_param")
     tensor.data = target
-    tensor.requires_grad = target.requires_grad
+    tensor.requires_grad = requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
     tensor.tolist = MethodType(_data_tolist, tensor)

@@ -212,9 +213,10 @@ def materialize(self) -> torch.Tensor:
         Returns:
             torch.Tensor: The materialized tensor (self).
         """
+        requires_grad = self.requires_grad
         target = self._materialize_data()
         self.clean()
-        return _convert_cls(self, target)
+        return _convert_cls(self, target, requires_grad=requires_grad)

     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
colossalai/shardformer/modeling/deepseek_v3.py (new file)

Lines changed: 277 additions & 0 deletions
@@ -0,0 +1,277 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
    DPGradScalerIn,
    DPGradScalerOut,
    EPGradScalerIn,
    EPGradScalerOut,
    all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

    def setup_process_groups(
        self,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
    ):
        assert moe_dp_group is not None
        assert ep_group is not None

        self.ep_size = dist.get_world_size(ep_group)
        self.ep_rank = dist.get_rank(ep_group)
        self.num_experts = self.config.n_routed_experts
        assert self.num_experts % self.ep_size == 0

        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.experts_per_rank = self.num_experts_per_ep
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

        set_tensors_to_none(self.experts, exclude=set(held_experts))

        # setup moe_dp group
        self.moe_dp_group = moe_dp_group
        self.moe_dp_size = dist.get_world_size(moe_dp_group)

        for p in self.experts.parameters():
            set_moe_tensor_ep_group(p, ep_group)

    @staticmethod
    def from_native_module(
        module,
        moe_dp_group: ProcessGroup,
        ep_group: ProcessGroup,
        *args,
        **kwargs,
    ) -> "EpDeepseekV3MoE":
        if module.__class__.__name__ != "DeepseekV3MLP":
            module.__class__ = EpDeepseekV3MoE
            module.setup_process_groups(moe_dp_group, ep_group)
        LazyInitContext.materialize(module)
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        if self.ep_size > 1:
            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
            input_split_sizes = tokens_per_ep_rank.tolist()

            gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
            s = 0
            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
                gatherd_idxs[s : s + k] = i % self.experts_per_rank
                s += k
            gatherd_idxs = gatherd_idxs.argsort()
            sorted_tokens = gathered_tokens[gatherd_idxs]
            tokens_per_expert = tokens_per_expert_post_gather

        # moe-dp related code
        activate_experts = tokens_per_expert_post_gather > 0
        activate_experts = activate_experts.int()
        dist.all_reduce(activate_experts, group=self.moe_dp_group)

        # ep related code
        sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            # moe-dp related code
            tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
            expert_out = expert(tokens_for_this_expert)
            # moe-dp related code
            expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
            outputs.append(expert_out)
            start_idx = end_idx

        if len(outputs) > 0:
            outs = torch.cat(outputs, dim=0)
        else:
            assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
            outs = sorted_tokens

        if self.ep_size > 1:
            outs = EPGradScalerOut.apply(outs, self.ep_size)
            new_x = torch.empty_like(outs)
            new_x[gatherd_idxs] = outs
            gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
            outs = gathered_tokens

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )

        return final_out


def deepseek_v3_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for i, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and i > 0:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
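
The bookkeeping at the top of moe_forward decides how many routed token copies each expert receives and, from that, the uneven all-to-all split sizes per EP rank. A small single-process sketch of that counting and sorting math, with made-up routing results (no process groups or collectives involved):

# Single-process illustration of the split-size bookkeeping in moe_forward above.
# All values are toy data: 4 routed experts sharded over 2 EP ranks, 3 tokens, top-2 routing.
import torch

num_experts, ep_size = 4, 2
topk_ids = torch.tensor([[0, 3], [1, 3], [0, 2]])

# 1) per-expert token counts on this rank (same one-hot scatter as moe_forward)
cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)            # -> tensor([2, 1, 1, 2])

# 2) input split sizes for all_to_all_uneven: counts folded per destination EP rank
tokens_per_ep_rank = tokens_per_expert.view(ep_size, -1).sum(dim=1)   # -> tensor([3, 3])

# 3) duplicated tokens reordered by expert id; idxs // top_k recovers each copy's source row
idxs = topk_ids.view(-1).argsort()
source_rows = idxs // topk_ids.shape[1]        # order within one expert may vary

print(tokens_per_expert.tolist(), tokens_per_ep_rank.tolist(), source_rows.tolist())

In the distributed path, tokens_per_expert is then exchanged with dist.all_to_all_single so every EP rank knows how many tokens it will receive for the experts it holds, and all_to_all_uneven moves the token data itself using these split sizes.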

colossalai/shardformer/policies/auto_policy.py

Lines changed: 7 additions & 0 deletions
@@ -167,6 +167,13 @@ class PolicyLocation:
     "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
         file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
     ),
+    # DeepseekV3
+    "transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
+    ),
+    "transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
+        file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
+    ),
     # Falcon
     "transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
         file_name="falcon", class_name="FalconModelPolicy"
