@@ -921,28 +921,28 @@ class MoERunner(TunableRunner):
    dynamic_tensor_initializers = [
        lambda shapes, dtype, device: torch.empty(
            shapes, device=device, dtype=dtype
-        ),  # output buffer
+        ),  # output buffer, [num_tokens, hidden_size]
        lambda shapes, dtype, device: torch.rand(
            shapes, device=device, dtype=dtype
-        ),  # routing_logits
+        ),  # routing_logits, [num_tokens, num_experts]
        lambda shapes, dtype, device: torch.empty(
            shapes, device=device, dtype=dtype
-        ),  # topk_ids buffer. empty since routing_logits is used
+        ),  # topk_ids buffer. empty since routing_logits is used. [num_tokens, topk]
        lambda shapes, dtype, device: torch.empty(
            shapes, device=device, dtype=dtype
-        ),  # expert_weights buffer. empty since routing_logits is used
+        ),  # expert_weights buffer. empty since routing_logits is used. [num_tokens, topk]
        lambda shapes, dtype, device: torch.randn(shapes, device=device).to(
            dtype
-        ),  # hidden_states
+        ),  # hidden_states, [num_tokens, hidden_size]
        lambda shapes, dtype, device: torch.ones(shapes, device=device).to(
            dtype
-        ),  # hidden_states_scale
+        ),  # hidden_states_scale, [num_tokens, hidden_size // sf_vec_size]
    ]
    # their first dimension is num_tokens which will be tuned
    tuning_config_with_hidden_states_scales = TuningConfig(
        dynamic_tensor_specs=(
            DynamicTensorSpec(
-                (0, 1, 2, 3, 4, 7),
+                (0, 1, 2, 3, 4, 5),
                (0, 0, 0, 0, 0, 0),
                get_last_power_of_2_num_tokens_buckets(8192),
                lambda x: min(last_positive_power_of_2(x), 8192),
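The renumbering above follows from the smaller dynamic inputs list: once gemm1_weights and gemm2_weights no longer travel through it, the six tensors whose leading dimension is num_tokens occupy indices 0 through 5, so hidden_states_scale moves from index 7 to index 5. A minimal sketch of the layout the new DynamicTensorSpec assumes (illustrative only, not part of the patch):

# Index-to-tensor mapping assumed by the new DynamicTensorSpec (illustrative sketch).
dynamic_input_layout = {
    0: "output",               # [num_tokens, hidden_size]
    1: "routing_logits",       # [num_tokens, num_experts]
    2: "topk_ids",             # [num_tokens, topk]
    3: "expert_weights",       # [num_tokens, topk]
    4: "hidden_states",        # [num_tokens, hidden_size]
    5: "hidden_states_scale",  # [num_tokens, hidden_size // sf_vec_size]
}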
@@ -972,6 +972,8 @@ def __init__(
        dtype_act: DtypeTrtllmGen,
        dtype_weights: DtypeTrtllmGen,
        use_deepseek_fp8: bool,
+        hidden_size: int,
+        intermediate_size: int,
        tile_tokens_dim: Optional[int] = None,
        tune_max_num_tokens: int = 8192,
    ):
@@ -981,6 +983,8 @@ def __init__(
        self.dtype_weights = dtype_weights
        self.use_deepseek_fp8 = use_deepseek_fp8
        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
        self.tile_tokens_dim = tile_tokens_dim

    def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
@@ -1016,17 +1020,8 @@ def get_valid_tactics(
            topk_ids,
            expert_weights,
            hidden_states,
-            gemm1_weights,
-            gemm2_weights,
            *extra_inputs,
        ) = inputs
-        hidden_size = hidden_states.shape[1]
-        if (
-            self.dtype_act == DtypeTrtllmGen.E2m1
-            or self.dtype_act == DtypeTrtllmGen.MxE2m1
-        ):  # packed into uint8
-            hidden_size *= 2
-        intermediate_size = gemm1_weights.shape[1] // 2
        num_tokens = routing_logits.shape[0]
        tile_tokens_dim = (
            self.get_tile_tokens_dim(num_tokens, self.top_k)
@@ -1039,8 +1034,8 @@ def get_valid_tactics(
            self.dtype_weights,
            self.use_deepseek_fp8,
            self.top_k,
-            hidden_size,
-            intermediate_size,
+            self.hidden_size,
+            self.intermediate_size,
            self.num_experts,
            num_tokens,
        )
@@ -1053,24 +1048,25 @@ def get_valid_tactics(
    def forward(
        self,
        inputs: List[torch.Tensor],
-        hidden_size: int,
-        intermediate_size: int,
        num_local_experts: int,
-        num_tokens: int,
-        routing_bias: Optional[torch.Tensor] = None,
-        gemm1_bias: Optional[torch.Tensor] = None,
-        gemm1_alpha: Optional[torch.Tensor] = None,
-        gemm1_beta: Optional[torch.Tensor] = None,
-        gemm1_clamp_limit: Optional[torch.Tensor] = None,
-        gemm2_bias: Optional[torch.Tensor] = None,
-        output1_scale_scalar: Optional[torch.Tensor] = None,
-        output1_scale_gate_scalar: Optional[torch.Tensor] = None,
-        output2_scale_scalar: Optional[torch.Tensor] = None,
-        n_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-        local_expert_offset: int = 0,
-        routed_scaling_factor: Optional[float] = None,
-        routing_method_type: int = 1,
+        routing_bias: Optional[torch.Tensor],
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: Optional[torch.Tensor],
+        gemm1_bias: Optional[torch.Tensor],
+        gemm1_alpha: Optional[torch.Tensor],
+        gemm1_beta: Optional[torch.Tensor],
+        gemm1_clamp_limit: Optional[torch.Tensor],
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: Optional[torch.Tensor],
+        gemm2_bias: Optional[torch.Tensor],
+        output1_scale_scalar: Optional[torch.Tensor],
+        output1_scale_gate_scalar: Optional[torch.Tensor],
+        output2_scale_scalar: Optional[torch.Tensor],
+        n_group: Optional[int],
+        topk_group: Optional[int],
+        local_expert_offset: int,
+        routed_scaling_factor: Optional[float],
+        routing_method_type: int,
        tactic: int = -1,
        do_preparation: bool = False,
    ):
@@ -1080,10 +1076,9 @@ def forward(
            topk_ids,
            expert_weights,
            hidden_states,
-            gemm1_weights,
-            gemm2_weights,
            *extra_inputs,
        ) = inputs
+        num_tokens = routing_logits.shape[0]
        tile_tokens_dim = (
            self.get_tile_tokens_dim(num_tokens, self.top_k)
            if self.tile_tokens_dim is None
@@ -1092,19 +1087,27 @@ def forward(

        extra_input_idx = 0
        if trtllm_gen_dtype_has_scale(self.dtype_act):
-            hidden_states_scale = (
-                extra_inputs[extra_input_idx].view(torch.float8_e4m3fn).reshape(-1)
-            )
+            hidden_states_scale = extra_inputs[extra_input_idx]
            extra_input_idx += 1
        else:
            hidden_states_scale = None
-        if trtllm_gen_dtype_has_scale(self.dtype_weights):
-            gemm1_weights_scale = extra_inputs[extra_input_idx]
-            gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
-            extra_input_idx += 2
-        else:
-            gemm1_weights_scale = None
-            gemm2_weights_scale = None
+        # sanity checks to ensure that dynamic tensors have the correct shapes
+        assert output.shape[0] == num_tokens, (
+            "output's first dimension must be batch size."
+        )
+        assert topk_ids.shape[0] == num_tokens, (
+            "topk_ids's first dimension must be batch size."
+        )
+        assert expert_weights.shape[0] == num_tokens, (
+            "expert_weights's first dimension must be batch size."
+        )
+        assert hidden_states.shape[0] == num_tokens, (
+            "hidden_states's first dimension must be batch size."
+        )
+        assert (
+            hidden_states_scale is None
+            or hidden_states_scale.shape[0] == num_tokens
+        ), "hidden_states_scale's first dimension must be batch size"

        # TODO(siyuan): support fp8
        moe_op.trtllm_fp4_block_scale_moe(
@@ -1126,11 +1129,11 @@ def forward(
            output1_scale_scalar,
            output1_scale_gate_scalar,
            output2_scale_scalar,
-            num_local_experts,
+            self.num_experts,
            self.top_k,
            n_group,
            topk_group,
-            intermediate_size,
+            self.intermediate_size,
            local_expert_offset,
            num_local_experts,
            routed_scaling_factor,
@@ -1147,7 +1150,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
        cls.tuning_config_with_hidden_states_scales = TuningConfig(
            dynamic_tensor_specs=(
                DynamicTensorSpec(
-                    (0, 1, 2, 3, 4, 7),
+                    (0, 1, 2, 3, 4, 5),
                    (0, 0, 0, 0, 0, 0),
                    get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                    lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
@@ -1402,6 +1405,8 @@ def trtllm_fp4_block_scale_moe_op(
        dtype_act=dtype_act,
        dtype_weights=dtype_weights,
        use_deepseek_fp8=False,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
        tile_tokens_dim=tile_tokens_dim,
        tune_max_num_tokens=tune_max_num_tokens,
    )
@@ -1416,29 +1421,25 @@ def trtllm_fp4_block_scale_moe_op(
        topk_ids,
        expert_weights,
        hidden_states,
-        gemm1_weights,
-        gemm2_weights,
    ]
-    # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
    if hidden_states_scale is not None:
        inputs.append(hidden_states_scale)
-    inputs.append(gemm1_weights_scale)
-    inputs.append(gemm2_weights_scale)

    _, tactic = tuner.choose_one(
        "flashinfer::trtllm_fp4_block_scale_moe",
        [moe_runner],
        tunning_config,
        inputs,
-        hidden_size=hidden_size,
-        intermediate_size=intermediate_size,
        num_local_experts=num_experts,
-        num_tokens=num_tokens,
        routing_bias=routing_bias,
+        gemm1_weights=gemm1_weights,
+        gemm1_weights_scale=gemm1_weights_scale,
        gemm1_bias=gemm1_bias,
        gemm1_alpha=gemm1_alpha,
        gemm1_beta=gemm1_beta,
        gemm1_clamp_limit=gemm1_clamp_limit,
+        gemm2_weights=gemm2_weights,
+        gemm2_weights_scale=gemm2_weights_scale,
        gemm2_bias=gemm2_bias,
        output1_scale_scalar=output1_scale_scalar,
        output1_scale_gate_scalar=output1_scale_gate_scalar,
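Taken together with the runner changes, the call site now fixes hidden_size and intermediate_size when the runner is constructed and forwards the weight and scale tensors as keyword arguments instead of dynamic inputs. A condensed sketch of the call path assumed from the diff (names taken from the hunks above, not a verbatim excerpt; arguments not shown in the diff are elided):

# Condensed sketch of the updated call path in trtllm_fp4_block_scale_moe_op
# (assumed from the diff above, not a verbatim excerpt).
moe_runner = MoERunner(
    # num_experts / top_k / other constructor arguments are unchanged and elided here
    dtype_act=dtype_act,
    dtype_weights=dtype_weights,
    use_deepseek_fp8=False,
    hidden_size=hidden_size,              # new: fixed at construction time
    intermediate_size=intermediate_size,  # new: no longer derived from gemm1_weights.shape
    tile_tokens_dim=tile_tokens_dim,
    tune_max_num_tokens=tune_max_num_tokens,
)

# Only tensors whose first dimension is num_tokens remain in the dynamic inputs.
inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states]
if hidden_states_scale is not None:
    inputs.append(hidden_states_scale)

# Static weights and their scales travel as keyword arguments into forward().
_, tactic = tuner.choose_one(
    "flashinfer::trtllm_fp4_block_scale_moe",
    [moe_runner],
    tunning_config,
    inputs,
    num_local_experts=num_experts,
    routing_bias=routing_bias,
    gemm1_weights=gemm1_weights,
    gemm1_weights_scale=gemm1_weights_scale,
    gemm2_weights=gemm2_weights,
    gemm2_weights_scale=gemm2_weights_scale,
    # remaining gemm bias/alpha/beta, output scale, and routing keyword
    # arguments continue exactly as shown in the diff
)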