@@ -395,7 +395,7 @@ def invoke_moe_batched_triton_kernel(
     assert max_num_tokens % BLOCK_M == 0

     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
-            triton.cdiv(B.shape[1], BLOCK_N))
+            triton.cdiv(B.size(1), BLOCK_N))

     batched_triton_kernel[grid](
         A,
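The recurring change in this diff swaps `Tensor.shape[i]` indexing for the `Tensor.size(i)` method. For a concrete tensor the two accessors are interchangeable, so the change is a consistency refactor rather than a behavioral one; a minimal check:

```python
import torch

t = torch.empty(4, 8, 16)

# size(i) and shape[i] return the same value for a concrete tensor;
# size() with no argument returns the full torch.Size.
assert t.size(1) == t.shape[1] == 8
assert t.size(-1) == t.shape[-1] == 16
assert t.size() == t.shape == torch.Size([4, 8, 16])
```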
@@ -493,17 +493,17 @@ def dispatch(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         assert a1.dim() == 2
         assert topk_ids.dim() == 2
-        assert topk_ids.shape[0] == a1.shape[0]
+        assert topk_ids.size(0) == a1.size(0)

         if apply_router_weight_on_input:
-            topk = topk_ids.shape[1]
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, \
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))

-        num_tokens, hidden_dim = a1.shape
-        topk = topk_ids.shape[1]
+        num_tokens, hidden_dim = a1.size()
+        topk = topk_ids.size(1)

         if self.max_num_tokens is None:
             tokens_per_expert = torch.bincount(topk_ids.view(-1),
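Two things happen in this hunk: `apply_router_weight_on_input` folds the (topk=1) routing weight into the activations before dispatch, and `torch.bincount` over the flattened `topk_ids` counts how many tokens each expert receives. A small sketch of both steps, with illustrative sizes that are not from the source:

```python
import torch

num_tokens, hidden_dim, num_experts, topk = 4, 8, 3, 1

a1 = torch.randn(num_tokens, hidden_dim)
topk_ids = torch.randint(num_experts, (num_tokens, topk))
topk_weights = torch.rand(num_tokens, topk)

# topk == 1: one routing weight per token, broadcast over hidden_dim.
a1.mul_(topk_weights.to(a1.dtype))

# Tokens per expert; minlength keeps experts that receive zero tokens.
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
assert tokens_per_expert.sum() == num_tokens * topk
```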
@@ -543,10 +543,10 @@ def combine(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
     ) -> None:
-        num_tokens = topk_ids.shape[0]
-        num_local_experts = fused_expert_output.shape[0]
-        K = fused_expert_output.shape[-1]
-        assert output.shape[0] == num_tokens and output.shape[1] == K
+        num_tokens = topk_ids.size(0)
+        num_local_experts = fused_expert_output.size(0)
+        K = fused_expert_output.size(-1)
+        assert output.size(0) == num_tokens and output.size(1) == K

         output.fill_(0)

@@ -559,7 +559,7 @@ def combine(
             rows = torch.count_nonzero(topks)
             rhs = fused_expert_output[expert_id - first_expert, :rows, :]
             if not apply_router_weight_on_input:
-                rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
+                rhs.mul_(topk_weights[topkws].view(rhs.size(0), 1))
             output[topks] = output[topks] + rhs


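For context: `combine` scatters each expert's contiguous batch of output rows back to the token positions that were routed to it, applying `topk_weights` unless the weight was already folded into the input. A hedged sketch of that per-expert gather/weight/accumulate pattern, simplified to topk=1 and a zero `first_expert` offset (all names and sizes illustrative):

```python
import torch

num_tokens, K, num_experts = 4, 8, 2
topk_ids = torch.tensor([[0], [1], [0], [1]])          # topk = 1
topk_weights = torch.rand(num_tokens, 1)
# One contiguous batch of rows per expert, padded to num_tokens.
fused_expert_output = torch.randn(num_experts, num_tokens, K)

output = torch.zeros(num_tokens, K)
for expert_id in range(num_experts):
    topks = (topk_ids == expert_id).any(dim=1)         # mask of routed tokens
    rows = int(torch.count_nonzero(topks))
    rhs = fused_expert_output[expert_id, :rows, :].clone()
    rhs.mul_(topk_weights[topks].view(rhs.size(0), 1))  # apply routing weight
    output[topks] = output[topks] + rhs
```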
@@ -599,8 +599,8 @@ def workspace_shapes(
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.shape[
-            0] if self.max_num_tokens is None else self.max_num_tokens
+        max_num_tokens = a.size(
+            0) if self.max_num_tokens is None else self.max_num_tokens
         #print(f"WORKSPACE {max_num_tokens} {num_dp}")
         workspace13 = num_experts * max_num_tokens * num_dp * K
         workspace2 = max_num_tokens * num_dp * N
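The workspace sizes here are flat element counts: the batched layout reserves up to `max_num_tokens * num_dp` rows per expert, so `workspace13` must cover `num_experts * max_num_tokens * num_dp * K` elements. A worked example with made-up sizes:

```python
# Illustrative numbers only: 8 experts, 64 max tokens per DP rank,
# 2 DP ranks on this worker, hidden size K=4096, intermediate N=14336.
num_experts, max_num_tokens, num_dp, K, N = 8, 64, 2, 4096, 14336

workspace13 = num_experts * max_num_tokens * num_dp * K   # 4,194,304 elements
workspace2 = max_num_tokens * num_dp * N                  # 1,835,008 elements
```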
@@ -627,27 +627,27 @@ def apply(
     ) -> torch.Tensor:
         assert hidden_states.dim() == 3
         assert expert_num_tokens is not None
-        hidden_dim = hidden_states.shape[-1]
+        hidden_dim = hidden_states.size(-1)

         if self.max_num_tokens is None:
-            max_num_tokens = hidden_states.shape[1]
+            max_num_tokens = hidden_states.size(1)
         else:
             max_num_tokens = self.max_num_tokens

         num_dp = self.world_size // self.dp_size
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens * num_dp, hidden_dim))
-        num_local_experts = w1.shape[0]
-        assert num_local_experts == w1.shape[
-            0], f"{num_local_experts} == {w1.shape[0]}"
+        num_local_experts = w1.size(0)
+        assert num_local_experts == w1.size(0), (
+            f"{num_local_experts} == {w1.size(0)}")

-        N = w1.shape[1] // 2
+        N = w1.size(1) // 2

         # Not cudagraph friendly
         assert (torch.cuda.is_current_stream_capturing()
-                or torch.all(expert_num_tokens <= max_num_tokens)), (
-                    f"{expert_num_tokens} <= {max_num_tokens}")
+                or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
+                    f"{expert_num_tokens} <= {max_num_tokens * num_dp}")

         for expert in range(num_local_experts):
             # Indexing expert_num_tokens doesn't work w/cudagraphs
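This hunk loosens the bound to `max_num_tokens * num_dp`, since an expert's batch can collect tokens from every DP rank, and it also illustrates a CUDA graph constraint: `torch.all(...)` on a device tensor is a data-dependent read that cannot be evaluated while a graph is being captured, so the check only runs eagerly. A minimal sketch of that guard pattern, assuming a CUDA build (helper name is invented):

```python
import torch

def check_expert_counts(expert_num_tokens: torch.Tensor, bound: int) -> None:
    # Materializing torch.all(...) as a Python bool is illegal inside CUDA
    # graph capture, so the data-dependent assert is skipped during capture.
    assert (torch.cuda.is_current_stream_capturing()
            or bool(torch.all(expert_num_tokens <= bound))), (
        f"{expert_num_tokens} <= {bound}")
```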
@@ -699,8 +699,8 @@ def workspace_shapes(
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.shape[
-            0] if self.max_num_tokens is None else self.max_num_tokens
+        max_num_tokens = a.size(
+            0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
         workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)
@@ -726,12 +726,12 @@ def apply(
     ) -> torch.Tensor:
         # Check constraints.
         if self.use_int4_w4a16:
-            assert hidden_states.shape[-1] // 2 == w1.shape[
-                2], "Hidden size mismatch"
+            assert hidden_states.size(-1) // 2 == w1.size(2), (
+                "Hidden size mismatch")
         else:
-            assert hidden_states.shape[-1] == w1.shape[2], \
-                (f"Hidden size mismatch {hidden_states.shape[-1]} "
-                 f"!= {w1.shape[2]}")
+            assert hidden_states.size(-1) == w1.size(2), (
+                f"Hidden size mismatch {hidden_states.size(-1)} "
+                f"!= {w1.size(2)}")

         assert hidden_states.is_contiguous(
         ), "Hidden_states must be contiguous"
@@ -745,17 +745,17 @@ def apply(
         E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)

-        assert w1.shape[0] == E
-        assert w2.shape[0] == E
+        assert w1.size(0) == E
+        assert w2.size(0) == E

         config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
                                             use_int8_w8a16=self.use_int8_w8a16,
                                             use_int4_w4a16=self.use_int4_w4a16,
                                             dtype=hidden_states.dtype)

         config = try_get_optimal_moe_config(
-            w1.shape,
-            w2.shape,
+            w1.size(),
+            w2.size(),
             top_k_num,
             config_dtype,
             num_tokens,
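Passing `w1.size()` where `w1.shape` used to go is safe because `Tensor.size()` returns a `torch.Size`, which is a `tuple` subclass and works anywhere a plain shape tuple is accepted. Quick demonstration:

```python
import torch

w1 = torch.empty(8, 256, 128)
shape = w1.size()

assert isinstance(shape, tuple)     # torch.Size subclasses tuple
assert shape == (8, 256, 128)
assert shape[1:] == (256, 128)      # ordinary tuple operations still work
```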
@@ -797,13 +797,13 @@ def apply(
             config=config,
             block_shape=self.block_shape)

-        # Fix activations
-        if True:
-            assert activation == "silu"
+        if activation == "silu":
             invoke_batched_silu_and_mul(output=intermediate_cache2,
                                         input=intermediate_cache1,
                                         expert_num_tokens=expert_num_tokens)
         else:
+            # TODO: would be nice to use expert_num_tokens here to reduce
+            # garbage compute
             self.activation(activation, intermediate_cache2.view(-1, N // 2),
                             intermediate_cache1.view(-1, N))

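The rewritten branch dispatches on the activation name instead of asserting it: the batched SiLU-and-mul kernel consumes the doubled intermediate of width `N` and writes width `N // 2`, computing `silu(gate) * up`. A reference (non-Triton) version of that contract for comparison, assuming the usual gate-first/up-second layout:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Gate activation in the first half, up-projection in the second."""
    d = x.size(-1) // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16, 256)                         # (..., N)
assert silu_and_mul_ref(x).size() == (4, 16, 128)   # (..., N // 2)
```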