@@ -241,7 +241,7 @@ def apply(
     # Maybe extra args
     def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool:
         block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size)
-        print(f"block_m = {block_m}")
+        # print(f"block_m = {block_m}")

         experts = TritonExperts(
             use_fp8_w8a8=False,
@@ -550,8 +550,8 @@ def __init__(
         self.ep_size = 1
         self.local_num_experts = self.global_num_experts
         self.expert_map = None
+        # self.global_num_experts = num_experts  redundant?
         self.top_k = top_k
-        self.global_num_experts = num_experts

         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -571,11 +571,12 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
+
         if current_platform.is_hpu():
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

-        print(f"params dtype= {params_dtype}")
+        # print(f"params dtype= {params_dtype}")

         moe = MoEConfig(
             num_experts=self.global_num_experts,
@@ -604,59 +605,59 @@ def __init__(
         self.quant_method = quant_method

         # TODO: move to method?
-        if self.dp_size > 1:
-            if True:
-                max_num_tokens = MOE_DP_CHUNK_SIZE  # // moe.dp_size
-                world_size = moe.ep_size
-                dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
-                rank = moe.ep_rank
+        if False and self.dp_size > 1:
+            max_num_tokens = MOE_DP_CHUNK_SIZE  # // moe.dp_size
+            world_size = moe.ep_size
+            dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
+            rank = moe.ep_rank

+            if False:
                 print(f"max num = {max_num_tokens}")
                 print(f"world size = {world_size}")
                 print(f"moe ep size = {moe.ep_size}")
                 print(f"moe dp size = {moe.dp_size}")
                 print(f"dp size = {dp_size}")
                 print(f"rank= {rank}")

-                all_to_all = get_all_to_all(
-                    max_num_tokens=max_num_tokens,
-                    num_experts=moe.num_experts,
-                    experts_per_token=moe.experts_per_token,  # topk
-                    rank=rank,
-                    world_size=world_size,
-                    dp_size=dp_size,
-                    hidden_dim=moe.hidden_dim,
-                    hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-                    # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32)
-                    # For per-token: set to sizeof(float32)
-                    hidden_dim_scale_bytes=(
-                        0
-                        if moe.in_dtype.itemsize != 1
-                        else (
-                            (moe.hidden_dim + moe.block_size - 1)
-                            // moe.block_size
-                            * torch.float32.itemsize
-                        )
+            all_to_all = get_all_to_all(
+                max_num_tokens=max_num_tokens,
+                num_experts=moe.num_experts,
+                experts_per_token=moe.experts_per_token,  # topk
+                rank=rank,
+                world_size=world_size,
+                dp_size=dp_size,
+                hidden_dim=moe.hidden_dim,
+                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+                # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32)
+                # For per-token: set to sizeof(float32)
+                hidden_dim_scale_bytes=(
+                    0
+                    if moe.in_dtype.itemsize != 1
+                    else (
+                        (moe.hidden_dim + moe.block_size - 1)
+                        // moe.block_size
+                        * torch.float32.itemsize
                     )
                 )
+            )

-                dispatch_combine = PplxDispatchCombine(
-                    all_to_all,
-                    max_num_tokens,
-                    world_size,
-                    dp_size,
-                    rank,  # just for debugging
-                    moe.in_dtype,
-                )
-            else:
-                dispatch_combine = StandardDispatchCombine(
-                    moe.in_dtype,
-                    quant_config.weight_block_size if quant_config is not None else None,
-                )
+            dispatch_combine = PplxDispatchCombine(
+                all_to_all,
+                max_num_tokens,
+                world_size,
+                dp_size,
+                rank,  # just for debugging
+                moe.in_dtype,
+            )

             success = self.quant_method.set_dispatch_combine(dispatch_combine)
             if not success:
                 logger.warning("DP+EP not supported for %s.", type(self.quant_method))
+        else:
+            dispatch_combine = StandardDispatchCombine(
+                moe.in_dtype,
+                quant_config.weight_block_size if quant_config is not None else None,
+            )

         self.apply_router_weight_on_input = apply_router_weight_on_input
         moe_quant_params = {
@@ -1010,7 +1011,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
         num_tokens_across_dp = get_forward_context(
         ).dp_metadata.num_tokens_across_dp

-        print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}")
+        # print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}")

         # In this function we define two ranges:
         #  1. chunk_range - The current iteration of the loop's range over the DP world tokens
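
Aside: a minimal standalone sketch of the hidden_dim_scale_bytes arithmetic passed to get_all_to_all(...) in the constructor hunk above, assuming a PyTorch build where torch.dtype exposes .itemsize (which the diff itself relies on). The helper names ceil_div and scale_bytes_per_token are illustrative only and are not part of this change.

import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def scale_bytes_per_token(in_dtype: torch.dtype, hidden_dim: int, block_size: int) -> int:
    # Illustrative sketch, not code from the diff.
    # Unquantized activations (itemsize > 1, e.g. bf16/fp16) carry no per-token scales.
    if in_dtype.itemsize != 1:
        return 0
    # Blocked per-token quantization: one float32 scale per block of the hidden dimension.
    return ceil_div(hidden_dim, block_size) * torch.float32.itemsize


# e.g. fp8 activations with hidden_dim=4096, block_size=128 -> 32 scales * 4 bytes = 128 bytes
assert scale_bytes_per_token(torch.float8_e4m3fn, 4096, 128) == 128
assert scale_bytes_per_token(torch.bfloat16, 4096, 128) == 0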