
Commit 8243b2b

[Model] Support multi-GPU for Deepseek-v2 (#3080)
This PR supports tensor parallelism for the Deepseek-v2 model.
1 parent 6faf68e commit 8243b2b
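
For orientation, the sharding scheme the diff sets up is the usual one: first projections are split along their output dimension, second projections along their input dimension, and each decoder block's output is all-reduced once before the residual add (these correspond to the ShardSingleDim hints and the _apply_residual all-reduce further down). A minimal NumPy sketch of that arithmetic, with made-up shapes; it is illustrative only, not the MLC code path:

import numpy as np

# Illustrative only: simulate 2-way tensor parallelism for one MLP block.
rng = np.random.default_rng(0)
hidden, inter, shards = 8, 16, 2

x = rng.standard_normal(hidden)
w_up = rng.standard_normal((inter, hidden))    # sharded on dim 0 (output features)
w_down = rng.standard_normal((hidden, inter))  # sharded on dim 1 (input features)

# Each "GPU" holds a slice of the weights and produces a partial output.
partials = []
for s in range(shards):
    rows = slice(s * inter // shards, (s + 1) * inter // shards)
    h_s = w_up[rows] @ x                        # local activation, no communication needed
    partials.append(w_down[:, rows] @ h_s)      # partial sum of the full matmul

# One all-reduce (sum) recovers the single-GPU result before the residual add.
out_tp = np.sum(partials, axis=0)
out_ref = w_down @ (w_up @ x)
assert np.allclose(out_tp, out_ref)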

2 files changed: +95 additions, −18 deletions

python/mlc_llm/model/deepseek/deepseek_model.py

Lines changed: 2 additions & 2 deletions

@@ -307,10 +307,10 @@ def _set(layer, hint):
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
         out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = self._apply_residual(hidden_states, residual=out)
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
-        hidden_states = self._apply_residual(hidden_states, residual=out)
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states
 
     def _apply_residual(self, out, residual):
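
The swap above matters because _apply_residual all-reduces its first argument: when the layer is sharded, the branch output out is a partial sum that needs the all-reduce, while the residual hidden_states is already replicated on every worker. A toy sketch of the difference (allreduce_sum is a hypothetical stand-in for op.ccl_allreduce, not the MLC op):

import numpy as np

def allreduce_sum(per_worker):
    # Hypothetical stand-in for op.ccl_allreduce(x, "sum"): every worker gets the sum.
    return np.sum(per_worker, axis=0)

residual = np.array([1.0, 1.0])                 # hidden_states, replicated on both workers
partial_out = [np.array([0.5, 0.5]),            # worker 0's partial attention/MLP output
               np.array([0.25, 0.25])]          # worker 1's partial attention/MLP output

# New call: _apply_residual(out, residual=hidden_states)
correct = allreduce_sum(partial_out) + residual           # [1.75, 1.75] on every worker

# Old call: _apply_residual(hidden_states, residual=out) would all-reduce the
# replicated residual (double-counting it) and add only a partial block output.
wrong_worker0 = allreduce_sum([residual, residual]) + partial_out[0]   # [2.5, 2.5]
print(correct, wrong_worker0)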

python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py

Lines changed: 93 additions & 16 deletions
@@ -16,6 +16,7 @@
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold

@@ -79,20 +80,17 @@ def __post_init__(self):
             logger.info(
                 "%s defaults to %d",
                 bold("prefill_chunk_size"),
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
         elif self.prefill_chunk_size > self.context_window_size:
             logger.info(
                 "Overriding %s from %d to %d",
                 bold("prefill_chunk_size"),
                 self.prefill_chunk_size,
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
-
-        if self.tensor_parallel_shards != 1:
-            raise ValueError("Only support single device at this moment.")
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
 
 
 # pylint: disable=invalid-name,missing-docstring,too-many-locals

@@ -102,9 +100,15 @@ class DeepseekV2MLP(nn.Module):
     def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
         super().__init__()
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
-        self.intermediate_size = (
+        intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
         )
+        if intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MoE intermediate size {intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.intermediate_size = intermediate_size // config.tensor_parallel_shards
 
         self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

@@ -173,7 +177,12 @@ def __init__(self, config: DeepseekV2Config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
+        if config.num_attention_heads % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.num_attention_heads} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
 
         self.rope_theta = config.rope_theta
         self.q_lora_rank = config.q_lora_rank

@@ -320,7 +329,12 @@ def __init__(self, config: DeepseekV2Config):
 
         self.gate = nn.Linear(config.hidden_size, self.num_routed_experts, bias=False)
         self.norm_topk_prob = config.norm_topk_prob
-        self.moe_intermediate_size = config.moe_intermediate_size
+        if config.moe_intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MoE intermediate size {config.moe_intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
 
         self.moe_gate_up_proj = MixtralExperts(
             self.num_routed_experts,
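
The same check-then-divide pattern appears three times above (dense MLP intermediate size, attention heads, MoE intermediate size): validate divisibility, then keep only the per-GPU slice size. A small sketch of the idea with a hypothetical helper and made-up sizes (the diff inlines the check at each site rather than sharing a function):

def split_evenly(size: int, shards: int, what: str) -> int:
    # Hypothetical helper mirroring the inlined checks in the diff.
    if size % shards != 0:
        raise ValueError(f"Cannot split {what} {size} evenly to {shards} GPUs.")
    return size // shards

# Made-up sizes, purely to show the arithmetic each worker ends up with:
heads_per_gpu = split_evenly(128, 4, "attention heads")             # 32 heads per GPU
moe_inter_per_gpu = split_evenly(1536, 4, "MoE intermediate size")  # 384 per GPU
print(heads_per_gpu, moe_inter_per_gpu)
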
@@ -333,8 +347,9 @@ def __init__(self, config: DeepseekV2Config):
             out_features=config.hidden_size,
         )
 
-        intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-        self.shared_experts = DeepseekV2MLP(config, intermediate_size=intermediate_size)
+        self.shared_experts = DeepseekV2MLP(
+            config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
 
     def forward(self, x: Tensor):
         def _expert_forward(x: Tensor, indptr: Tensor):
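
One subtlety in the shared-experts change: the constructor passes the full, unsharded width config.moe_intermediate_size * config.n_shared_experts into DeepseekV2MLP, which now does the per-GPU division itself. It cannot reuse self.moe_intermediate_size here, because a few lines earlier that attribute was already reduced to the per-shard value, and reusing it would divide by the shard count twice. With hypothetical numbers:

# Hypothetical sizes, for illustration only.
moe_intermediate_size = 1536      # full width from the config
n_shared_experts = 2
tensor_parallel_shards = 4

# DeepseekV2MLP receives the full width and shards it once internally:
shared_per_gpu = (moe_intermediate_size * n_shared_experts) // tensor_parallel_shards   # 768

# Reusing the already-sharded attribute would shrink it a second time:
per_shard_attr = moe_intermediate_size // tensor_parallel_shards                        # 384
too_small = (per_shard_attr * n_shared_experts) // tensor_parallel_shards               # 192
print(shared_per_gpu, too_small)
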
@@ -404,15 +419,72 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int):
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
 
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            if self.self_attn.q_lora_rank is None:
+                _set(
+                    self.self_attn.q_proj.weight,
+                    tp.ShardSingleDim("_shard_q_weight", dim=0),
+                )
+            else:
+                _set(
+                    self.self_attn.q_b_proj.weight,
+                    tp.ShardSingleDim("_shard_q_b_weight", dim=0),
+                )
+
+            _set(
+                self.self_attn.kv_b_proj.weight,
+                tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+
+            if isinstance(self.mlp, DeepseekV2MoE):
+                si = self.mlp.shared_experts.intermediate_size
+                mi = self.mlp.moe_intermediate_size
+                _set(
+                    self.mlp.shared_experts.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_shared_experts_gate_up", segs=[si, si], dim=0),
+                )
+                _set(
+                    self.mlp.shared_experts.down_proj.weight,
+                    tp.ShardSingleDim("_shard_shared_experts_down", dim=1),
+                )
+                _set(
+                    self.mlp.moe_gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_moe_gate_up", segs=[mi, mi], dim=1),
+                )
+                _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
+            else:
+                assert isinstance(self.mlp, DeepseekV2MLP)
+                si = self.mlp.intermediate_size
+                _set(
+                    self.mlp.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_gate_up", segs=[si, si], dim=0),
+                )
+                _set(
+                    self.mlp.down_proj.weight,
+                    tp.ShardSingleDim("_shard_down", dim=1),
+                )
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
         out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states
 
+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+
 
 class DeepseekV2Model(nn.Module):
     def __init__(self, config: DeepseekV2Config):
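
On the shard hints: tp.ShardSingleDim names one dimension of a parameter to split across workers, and segs (used for the fused gate_up weights) first cuts that dimension into segments that are sharded separately, so each worker keeps a contiguous slice of the gate half and of the up half rather than an arbitrary row range. A NumPy sketch of that layout under those assumptions (illustrative, not the MLC implementation):

import numpy as np

# Sketch of ShardSingleDim("_shard_gate_up", segs=[si, si], dim=0) for a fused
# gate_up weight of shape [2 * si, hidden], split across `shards` workers.
si, hidden, shards = 8, 4, 2
gate_up = np.arange(2 * si * hidden, dtype=np.float32).reshape(2 * si, hidden)

gate, up = gate_up[:si], gate_up[si:]            # the two segments along dim 0
per_worker = [
    np.concatenate(
        [np.array_split(gate, shards, axis=0)[w], np.array_split(up, shards, axis=0)[w]],
        axis=0,
    )
    for w in range(shards)
]
# Each worker holds a [2 * si // shards, hidden] slice: its rows of gate followed
# by its rows of up, matching the halved intermediate_size computed in __init__.
assert per_worker[0].shape == (2 * si // shards, hidden)

The o_proj and down_proj weights are sharded on their input dimension (dim=1) instead, which is why each worker's block output is only a partial sum and why _apply_residual all-reduces it before adding the residual.
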
@@ -446,6 +518,7 @@ def __init__(self, config: DeepseekV2Config):
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
+        self.tensor_parallel_shards = config.tensor_parallel_shards
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)

@@ -469,6 +542,8 @@ def batch_forward(
         return logits
 
     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)
 
     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):

@@ -497,6 +572,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
     def batch_prefill(
         self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
     ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
         logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
         return logits, paged_kv_cache
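
The guards in the two hunks above keep the single-GPU path free of communication; with more than one shard, op.ccl_broadcast_from_worker0 copies the raw inputs (token ids, logit positions) that only worker 0 holds out to the other workers before the sharded compute runs. A toy sketch of the pattern (a simulated mailbox stands in for the real collective):

# Toy sketch, not the MLC runtime: both simulated workers run the same code,
# but only worker 0 starts with real token ids, so they are broadcast first.
def broadcast_from_worker0(local_value, mailbox):
    # Hypothetical stand-in for op.ccl_broadcast_from_worker0.
    if local_value is not None:
        mailbox["value"] = local_value       # worker 0 publishes its inputs
    return mailbox["value"]

tensor_parallel_shards = 2
mailbox = {}
worker_inputs = {0: [5, 17, 42], 1: None}    # only worker 0 has the prompt tokens

for worker_id in sorted(worker_inputs):
    ids = worker_inputs[worker_id]
    if tensor_parallel_shards > 1:
        ids = broadcast_from_worker0(ids, mailbox)
    # ...each worker now embeds the same ids against its own weight shards...
    print(worker_id, ids)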

@@ -523,8 +600,8 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
             page_size=page_size,
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
+            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
             head_dim=256,
             rope_mode=RopeMode.NONE,
             rope_scale=1,
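
Because every worker computes only its own slice of the attention heads, its paged KV cache also only needs to hold those heads, so the head counts passed to the cache are divided by the shard count while head_dim stays at 256. A back-of-the-envelope with made-up numbers:

# Made-up sizing example; the real head counts come from the model config.
num_attention_heads = 128
tensor_parallel_shards = 4
head_dim = 256                     # fixed above in create_paged_kv_cache

heads_per_gpu = num_attention_heads // tensor_parallel_shards    # 32
# Elements cached per token per worker (keys and values for its heads only):
kv_elems_per_token = 2 * heads_per_gpu * head_dim                # 16384
print(heads_per_gpu, kv_elems_per_token)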
