     device_support_pdl,
     get_shuffle_matrix_a_row_indices,
     get_shuffle_matrix_sf_a_row_indices,
+    calculate_tile_tokens_dim,
     register_custom_op,
     register_fake_op,
 )
 from .utils import (
     get_last_power_of_2_num_tokens_buckets,
     last_positive_power_of_2,
-    next_positive_power_of_2,
 )

@@ -894,30 +894,6 @@ def __init__(
         self.gated_act_type = gated_act_type
         self.tile_tokens_dim = tile_tokens_dim

-    def get_tile_tokens_dim(
-        self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128
-    ):
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        # tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_local_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-        if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
-            tile_tokens_dim = 192
-        # Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
-        return tile_tokens_dim
-
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -933,7 +909,12 @@ def get_valid_tactics(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
+            calculate_tile_tokens_dim(
+                num_tokens,
+                self.num_local_experts,
+                self.top_k,
+                64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128,
+            )
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
@@ -977,7 +958,12 @@ def forward(
         ) = inputs
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
-            self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
+            calculate_tile_tokens_dim(
+                num_tokens,
+                self.num_local_experts,
+                self.top_k,
+                64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128,
+            )
             if self.tile_tokens_dim is None
             else self.tile_tokens_dim
         )
@@ -1005,7 +991,6 @@ def forward(
             hidden_states_scale.dim() == 2
             and hidden_states_scale.shape[0] == num_tokens
         ), "hidden_states_scale's first dimension must be batch size"
-
         # TODO(siyuan): support fp8
         moe_op.trtllm_fp4_block_scale_moe(
             routing_logits,
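
For context: the diff drops the class-local get_tile_tokens_dim method in favor of a shared calculate_tile_tokens_dim utility and lowers the tile cap to 64 when the activation dtype is Bfloat16 (128 otherwise). The utility's body is not shown in this diff; the sketch below is an assumption that it mirrors the removed method's logic, with the expert count passed in explicitly and a local next_positive_power_of_2 standing in for the helper imported from .utils.

def calculate_tile_tokens_dim(
    num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
) -> int:
    # Local stand-in for the utils next_positive_power_of_2 helper.
    def next_positive_power_of_2(x: int) -> int:
        return 1 if x < 1 else 1 << (x - 1).bit_length()

    # Imbalance factor: max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
    imbalance_factor = 1.3
    # Tokens per expert under a perfect distribution, scaled by the imbalance factor.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # Pad to the next power of 2, with a dedicated 192 tile between 128 and 256.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    if 128 < num_tokens_per_expert < 256:
        tile_tokens_dim = 192
    # Clamp to the 8..max_tile_tokens_dim range supported by the kernel.
    return min(max(tile_tokens_dim, 8), max_tile_tokens_dim)

At the call sites in get_valid_tactics and forward, the cap passed as max_tile_tokens_dim is 64 for Bfloat16 activations and 128 otherwise, and the computed value is only used when self.tile_tokens_dim was not set explicitly.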