@@ -587,6 +587,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
         self,
+        world_size: int,
+        dp_size: int,
         max_num_tokens: Optional[int] = None,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -603,6 +605,8 @@ def __init__(
         assert not use_int8_w8a16, "NYI"
         assert not use_int4_w4a16, "NYI"
         self.max_num_tokens = max_num_tokens
+        self.world_size = world_size
+        self.dp_size = dp_size
 
     def workspace_shapes(
         self,
@@ -614,10 +618,12 @@ def workspace_shapes(
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
         max_num_tokens = a.shape[
             0] if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * K
-        workspace2 = max_num_tokens * N
+        # print(f"WORKSPACE {max_num_tokens} {num_dp}")
+        workspace13 = num_experts * max_num_tokens * num_dp * K
+        workspace2 = max_num_tokens * num_dp * N
         return (workspace13, workspace2, a.dtype)
 
     def apply(
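
Note on the `workspace_shapes` change above: with batched all-to-all dispatch, a rank's workspace must hold up to `max_num_tokens` activations from each of the `world_size // dp_size` ranks that can send it tokens, so both scratch buffers grow by that factor. A minimal sketch of the sizing arithmetic, mirroring the hunk (the helper name and example numbers are illustrative, not from the diff):

```python
# Sketch: flat element counts for the two BatchedExperts scratch buffers.
# K is the hidden size and N the w1 output size, matching workspace_shapes().
def batched_workspace_elems(num_experts: int, max_num_tokens: int,
                            world_size: int, dp_size: int,
                            K: int, N: int) -> tuple[int, int]:
    num_dp = world_size // dp_size  # ranks whose tokens may land on this rank
    workspace13 = num_experts * max_num_tokens * num_dp * K
    workspace2 = max_num_tokens * num_dp * N
    return workspace13, workspace2

# e.g. 8 ranks, dp_size=1, 64 experts, 256 tokens per rank, K=4096, N=14336:
print(batched_workspace_elems(64, 256, 8, 1, 4096, 14336))
```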
@@ -648,23 +654,24 @@ def apply(
         else:
             max_num_tokens = self.max_num_tokens
 
+        num_dp = self.world_size // self.dp_size
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
-                            (num_experts, max_num_tokens, hidden_dim))
+                            (num_experts, max_num_tokens * num_dp, hidden_dim))
         num_local_experts = w1.shape[0]  # expert_num_tokens.numel()
         assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}"
 
         N = w1.shape[1] // 2
 
         # Not cudagraph friendly
-        assert (torch.cuda.is_current_stream_capturing() or
-                torch.all(expert_num_tokens <= max_num_tokens)), (
-                    f"{expert_num_tokens} <= {max_num_tokens}")
+        # assert (torch.cuda.is_current_stream_capturing() or
+        #         torch.all(expert_num_tokens <= max_num_tokens)), (
+        #             f"{expert_num_tokens} <= {max_num_tokens}")
 
         for expert in range(num_local_experts):
             # Indexing expert_num_tokens doesn't work w/cudagraphs
-            if torch.cuda.is_current_stream_capturing():
-                num = max_num_tokens
+            if True or torch.cuda.is_current_stream_capturing():
+                num = max_num_tokens * num_dp
             else:
                 num = int(expert_num_tokens[expert].item())
             tmp = _resize_cache(workspace2, (num, N))
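
The commented-out assert and the forced `if True` above work around the same constraint: reading `expert_num_tokens` on the host (evaluating `torch.all(...)` in an assert, or calling `.item()`) triggers a device-to-host sync, which is illegal while a CUDA graph is being captured. The loop therefore always processes the padded `max_num_tokens * num_dp` rows per expert so shapes stay static. A sketch of the constraint, with `rows_for_expert` as a hypothetical helper:

```python
import torch

def rows_for_expert(expert_num_tokens: torch.Tensor, expert: int,
                    padded_rows: int) -> int:
    """Hypothetical helper: choose a row count that is safe under capture."""
    if torch.cuda.is_current_stream_capturing():
        # .item() would force a device->host sync, which is not allowed
        # during CUDA graph capture; use the static upper bound instead so
        # kernel launch shapes are identical across graph replays.
        return padded_rows
    return int(expert_num_tokens[expert].item())
```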
@@ -675,166 +682,6 @@ def apply(
         return out
 
 
-def _apply(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
-    w1_zp: Optional[torch.Tensor],
-    w2_zp: Optional[torch.Tensor],
-    a1q_scale: Optional[torch.Tensor],
-    a2_scale: Optional[torch.Tensor],
-    workspace13: torch.Tensor,
-    workspace2: torch.Tensor,
-    expert_num_tokens: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    block_shape: Optional[List[int]],
-) -> torch.Tensor:
-    # Check constraints.
-    if use_int4_w4a16:
-        assert hidden_states.shape[-1] // 2 == w1.shape[
-            2], "Hidden size mismatch"
-    else:
-        assert hidden_states.shape[-1] == w1.shape[2], \
-            (f"Hidden size mismatch {hidden_states.shape[-1]} "
-             f"!= {w1.shape[2]}")
-
-    assert hidden_states.is_contiguous(
-    ), "Hidden_states must be contiguous"
-    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
-    ]
-
-    # TODO: num_tokens -> max_num_tokens?
-    E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
-        hidden_states, w1, w2, topk_ids)
-
-    assert w1.shape[0] == E
-    assert w2.shape[0] == E
-
-    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
-                                        use_int8_w8a16=use_int8_w8a16,
-                                        use_int4_w4a16=use_int4_w4a16,
-                                        dtype=hidden_states.dtype)
-
-    config = try_get_optimal_moe_config(
-        w1.shape,
-        w2.shape,
-        top_k_num,
-        config_dtype,
-        num_tokens,
-        block_shape=block_shape,
-    )
-
-    if hidden_states.dtype == torch.bfloat16:
-        compute_type = tl.bfloat16
-    elif hidden_states.dtype == torch.float16:
-        compute_type = tl.float16
-    elif hidden_states.dtype == torch.float32:
-        compute_type = tl.float32
-    elif hidden_states.dtype == torch.float8_e4m3fn:
-        compute_type = tl.bfloat16
-    else:
-        raise ValueError(
-            f"Unsupported compute_type: {hidden_states.dtype}")
-
-    # print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
-    # We can reuse the memory between these because by the time we need
-    # cache3, we're done with cache1
-    intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
-    intermediate_cache2 = _resize_cache(workspace2,
-                                        (E, num_tokens, N // 2))
-    intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
-
-    # MM1
-    invoke_moe_batched_triton_kernel(A=hidden_states,
-                                     B=w1,
-                                     C=intermediate_cache1,
-                                     expert_num_tokens=expert_num_tokens,
-                                     compute_type=compute_type,
-                                     A_scale=a1q_scale,
-                                     B_scale=w1_scale,
-                                     B_zp=w1_zp,
-                                     use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=use_int8_w8a16,
-                                     use_int4_w4a16=use_int4_w4a16,
-                                     config=config,
-                                     block_shape=block_shape)
-
-    # Fix activations
-    assert activation == "silu"
-    invoke_batched_silu_and_mul(output=intermediate_cache2,
-                                input=intermediate_cache1,
-                                expert_num_tokens=expert_num_tokens)
-
-    # qintermediate_cache2 = intermediate_cache2
-    a2q_scale = a2_scale
-    # TODO (varun): support w8a8
-    assert not use_fp8_w8a8
-    # if self.use_fp8_w8a8:
-    #     qintermediate_cache2, a2q_scale = _fp8_quantize(
-    #         intermediate_cache2, a2_scale, self.block_shape)
-
-    invoke_moe_batched_triton_kernel(A=intermediate_cache2,
-                                     B=w2,
-                                     C=intermediate_cache3,
-                                     expert_num_tokens=expert_num_tokens,
-                                     compute_type=compute_type,
-                                     A_scale=a2q_scale,
-                                     B_scale=w2_scale,
-                                     B_zp=w2_zp,
-                                     use_fp8_w8a8=use_fp8_w8a8,
-                                     use_int8_w8a16=use_int8_w8a16,
-                                     use_int4_w4a16=use_int4_w4a16,
-                                     config=config,
-                                     block_shape=block_shape)
-
-    return intermediate_cache3
-
-
-def _apply_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    w1_scale: Optional[torch.Tensor],
-    w2_scale: Optional[torch.Tensor],
-    w1_zp: Optional[torch.Tensor],
-    w2_zp: Optional[torch.Tensor],
-    a1q_scale: Optional[torch.Tensor],
-    a2_scale: Optional[torch.Tensor],
-    workspace13: torch.Tensor,
-    workspace2: torch.Tensor,
-    expert_num_tokens: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    block_shape: Optional[List[int]],
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="_apply",
-    op_func=_apply,
-    mutates_args=[],
-    fake_impl=_apply_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
-)
-
-
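
The block removed above was a temporary custom-op indirection: `_apply` held the real batched implementation, `_apply_fake` gave the compiler a shape-only fake implementation, and `direct_register_custom_op` exposed the pair as `torch.ops.vllm._apply`; the matching call site is deleted from `BatchedTritonExperts.apply` in a later hunk, so the body now runs inline. For reference, a toy sketch of the registration pattern being dropped (import path assumed, op deliberately trivial):

```python
import torch
from vllm.utils import direct_register_custom_op  # assumed location

def _double(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

def _double_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake impl: only shape/dtype propagation matters to torch.compile.
    return torch.empty_like(x)

direct_register_custom_op(
    op_name="_double_example",
    op_func=_double,
    mutates_args=[],
    fake_impl=_double_fake,
)
# Afterwards callable as torch.ops.vllm._double_example(t).
```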
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
     def __init__(
@@ -845,6 +692,8 @@ def __init__(
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
         block_shape: Optional[List[int]] = None,
+        world_size: int = 1,
+        dp_size: int = 1,
     ):
         super().__init__()
         self.use_fp8_w8a8 = use_fp8_w8a8
@@ -855,6 +704,8 @@ def __init__(
         self.max_num_tokens = max_num_tokens
         assert not use_int8_w8a8, "NYI"
         assert not use_int4_w4a16, "NYI"
+        self.world_size = world_size
+        self.dp_size = dp_size
 
     def workspace_shapes(
         self,
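
Both classes now carry `world_size` and `dp_size` so `workspace_shapes` can scale its buffers. A hypothetical construction showing the new keywords (other arguments left at their defaults; see the hunks above for the full signature):

```python
# world_size=8 with dp_size=2 gives num_dp = 8 // 2 = 4, so each expert's
# workspace slot is sized for four ranks' worth of tokens.
experts = BatchedTritonExperts(
    max_num_tokens=256,
    world_size=8,
    dp_size=2,
)
```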
@@ -866,10 +717,11 @@ def workspace_shapes(
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
+        num_dp = self.world_size // self.dp_size
         max_num_tokens = a.shape[
             0] if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * max(K, N)
-        workspace2 = num_experts * max_num_tokens * (N // 2)
+        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)
 
     def apply(
@@ -891,29 +743,6 @@ def apply(
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        return torch.ops.vllm._apply(
-            hidden_states,
-            w1,
-            w2,
-            topk_ids,
-            activation,
-            global_num_experts,
-            expert_map,
-            w1_scale,
-            w2_scale,
-            w1_zp,
-            w2_zp,
-            a1q_scale,
-            a2_scale,
-            workspace13,
-            workspace2,
-            expert_num_tokens,
-            self.use_fp8_w8a8,
-            self.use_int8_w8a16,
-            self.use_int4_w4a16,
-            self.block_shape,
-        )
-
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.shape[-1] // 2 == w1.shape[
@@ -988,10 +817,13 @@ def apply(
                                          block_shape=self.block_shape)
 
         # Fix activations
-        assert activation == "silu"
-        invoke_batched_silu_and_mul(output=intermediate_cache2,
-                                    input=intermediate_cache1,
-                                    expert_num_tokens=expert_num_tokens)
+        # assert activation == "silu"
+        # invoke_batched_silu_and_mul(output=intermediate_cache2,
+        #                             input=intermediate_cache1,
+        #                             expert_num_tokens=expert_num_tokens)
+        self.activation(activation,
+                        intermediate_cache2.view(-1, N // 2),
+                        intermediate_cache1.view(-1, N))
 
         # qintermediate_cache2 = intermediate_cache2
         a2q_scale = a2_scale
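
The final hunk swaps the hard-coded `invoke_batched_silu_and_mul` kernel for the class's generic `self.activation(...)` hook, applied to 2-D views that flatten the batched `(E, max_num_tokens, N)` buffer into `(E * max_num_tokens, N)`. Note the hook takes no `expert_num_tokens`, so padded rows are now activated too. For `activation == "silu"`, the semantics being computed are gated SiLU; a plain-PyTorch reference sketch (not the fused kernel vLLM dispatches to):

```python
import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # First half of the last dim gates the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(16, 512)   # stands in for intermediate_cache1.view(-1, N)
out = silu_and_mul(x)      # matches intermediate_cache2.view(-1, N // 2)
assert out.shape == (16, 256)
```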