 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 from mlc_llm.support.config import ConfigBase
 from mlc_llm.support.style import bold
@@ -79,20 +80,17 @@ def __post_init__(self):
             logger.info(
                 "%s defaults to %d",
                 bold("prefill_chunk_size"),
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
         elif self.prefill_chunk_size > self.context_window_size:
             logger.info(
                 "Overriding %s from %d to %d",
                 bold("prefill_chunk_size"),
                 self.prefill_chunk_size,
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
-
-        if self.tensor_parallel_shards != 1:
-            raise ValueError("Only support single device at this moment.")
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
 
 
 # pylint: disable=invalid-name,missing-docstring,too-many-locals
@@ -102,9 +100,15 @@ class DeepseekV2MLP(nn.Module):
     def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
         super().__init__()
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
-        self.intermediate_size = (
+        intermediate_size = (
             config.intermediate_size if intermediate_size is None else intermediate_size
         )
+        if intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MoE intermediate size {intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.intermediate_size = intermediate_size // config.tensor_parallel_shards
 
         self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
@@ -173,7 +177,12 @@ def __init__(self, config: DeepseekV2Config):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
+        if config.num_attention_heads % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split {config.num_attention_heads} attention heads "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards
 
         self.rope_theta = config.rope_theta
         self.q_lora_rank = config.q_lora_rank
@@ -320,7 +329,12 @@ def __init__(self, config: DeepseekV2Config):
 
         self.gate = nn.Linear(config.hidden_size, self.num_routed_experts, bias=False)
         self.norm_topk_prob = config.norm_topk_prob
-        self.moe_intermediate_size = config.moe_intermediate_size
+        if config.moe_intermediate_size % config.tensor_parallel_shards != 0:
+            raise ValueError(
+                f"Cannot split MoE intermediate size {config.moe_intermediate_size} "
+                f"evenly to {config.tensor_parallel_shards} GPUs."
+            )
+        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
 
         self.moe_gate_up_proj = MixtralExperts(
             self.num_routed_experts,
@@ -333,8 +347,9 @@ def __init__(self, config: DeepseekV2Config):
             out_features=config.hidden_size,
         )
 
-        intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-        self.shared_experts = DeepseekV2MLP(config, intermediate_size=intermediate_size)
+        self.shared_experts = DeepseekV2MLP(
+            config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
 
     def forward(self, x: Tensor):
         def _expert_forward(x: Tensor, indptr: Tensor):
@@ -404,15 +419,72 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int):
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
 
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            if self.self_attn.q_lora_rank is None:
+                _set(
+                    self.self_attn.q_proj.weight,
+                    tp.ShardSingleDim("_shard_q_weight", dim=0),
+                )
+            else:
+                _set(
+                    self.self_attn.q_b_proj.weight,
+                    tp.ShardSingleDim("_shard_q_b_weight", dim=0),
+                )
+
+            _set(
+                self.self_attn.kv_b_proj.weight,
+                tp.ShardSingleDim("_shard_kv_b_weight", dim=0),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+
+            if isinstance(self.mlp, DeepseekV2MoE):
+                si = self.mlp.shared_experts.intermediate_size
+                mi = self.mlp.moe_intermediate_size
+                _set(
+                    self.mlp.shared_experts.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_shared_experts_gate_up", segs=[si, si], dim=0),
+                )
+                _set(
+                    self.mlp.shared_experts.down_proj.weight,
+                    tp.ShardSingleDim("_shard_shared_experts_down", dim=1),
+                )
+                _set(
+                    self.mlp.moe_gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_moe_gate_up", segs=[mi, mi], dim=1),
+                )
+                _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
+            else:
+                assert isinstance(self.mlp, DeepseekV2MLP)
+                si = self.mlp.intermediate_size
+                _set(
+                    self.mlp.gate_up_proj.weight,
+                    tp.ShardSingleDim("_shard_gate_up", segs=[si, si], dim=0),
+                )
+                _set(
+                    self.mlp.down_proj.weight,
+                    tp.ShardSingleDim("_shard_down", dim=1),
+                )
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
         out = self.self_attn(out, paged_kv_cache, layer_id)
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
-        hidden_states = hidden_states + out
+        hidden_states = self._apply_residual(out, residual=hidden_states)
         return hidden_states
 
+    def _apply_residual(self, out, residual):
+        if self.tensor_parallel_shards > 1:
+            return op.ccl_allreduce(out, "sum") + residual
+        return out + residual
+
 
 class DeepseekV2Model(nn.Module):
     def __init__(self, config: DeepseekV2Config):
@@ -446,6 +518,7 @@ def __init__(self, config: DeepseekV2Config):
         self.rms_norm_eps = config.rms_norm_eps
         self.rope_theta = config.rope_theta
         self.vocab_size = config.vocab_size
+        self.tensor_parallel_shards = config.tensor_parallel_shards
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
@@ -469,6 +542,8 @@ def batch_forward(
         return logits
 
     def embed(self, input_ids: Tensor):
+        if self.tensor_parallel_shards > 1:
+            input_ids = op.ccl_broadcast_from_worker0(input_ids)
         return self.model.embed_tokens(input_ids)
 
     def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
@@ -497,6 +572,8 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
     def batch_prefill(
         self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache
     ):
+        if self.tensor_parallel_shards > 1:
+            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)
         logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)
         return logits, paged_kv_cache
 
@@ -523,8 +600,8 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
             page_size=page_size,
             support_sliding_window=support_sliding_window,
             num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
+            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,
+            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
             head_dim=256,
             rope_mode=RopeMode.NONE,
             rope_scale=1,
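Reviewer's note, not part of the patch: the MLP hunks above shard `gate_up_proj` along its output dimension (`dim=0`, with `segs=[si, si]` so the gate and up halves are split separately), shard `down_proj` along its input dimension (`dim=1`), and rely on the `allreduce("sum")` in `_apply_residual` to recombine the per-worker partial outputs. The NumPy sketch below checks that arithmetic with toy sizes; it uses no mlc_llm APIs, and the SiLU activation is assumed to match the model.

```python
# Minimal sketch: column-parallel gate_up + row-parallel down + sum-allreduce
# reproduces the unsharded MLP output. Toy sizes, illustrative only.
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def mlp(x, w_gate_up, w_down):
    # nn.Linear-style weights: (out_features, in_features)
    gate, up = np.split(x @ w_gate_up.T, 2, axis=-1)
    return (silu(gate) * up) @ w_down.T

rng = np.random.default_rng(0)
hidden, inter, shards = 8, 12, 2     # inter % shards == 0, as the patch enforces
local = inter // shards              # per-worker slice of the intermediate dim

x = rng.normal(size=(1, hidden))
w_gate_up = rng.normal(size=(2 * inter, hidden))
w_down = rng.normal(size=(hidden, inter))

ref = mlp(x, w_gate_up, w_down)      # unsharded reference

# Shard as in the diff: gate and up halves split separately along the output
# dim (segs=[si, si], dim=0); down_proj split along the input dim (dim=1).
w_gate, w_up = np.split(w_gate_up, 2, axis=0)
partials = []
for r in range(shards):
    rows = slice(r * local, (r + 1) * local)
    w_gu_r = np.concatenate([w_gate[rows], w_up[rows]], axis=0)
    partials.append(mlp(x, w_gu_r, w_down[:, rows]))

# The allreduce("sum") in _apply_residual plays the role of this sum.
assert np.allclose(sum(partials), ref)
```

The MoE expert weights carry a leading experts dimension, so the same split appears one axis later there (`dim=1` for `moe_gate_up_proj`, `dim=2` for `moe_down_proj`). Attention follows the analogous pattern with heads as the sharded unit, which is also why `create_paged_kv_cache` now receives head counts divided by `tensor_parallel_shards`: each worker only caches its local heads.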