@@ -372,8 +372,8 @@ class MoERunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(
             DynamicTensorSpec(
-                0,
-                0,
+                (0,),
+                (0,),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
             ),
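For reference, a minimal sketch of the tuple-based spec, assuming DynamicTensorSpec now takes a tuple of dynamic input-tensor indices and a tuple of dynamic dimension indices instead of bare ints (the helper names are the ones already used in this module; the snippet mirrors only the hunk above, not the full class attribute):

# Sketch only; argument semantics are assumed from the change above.
spec = DynamicTensorSpec(
    (0,),                                              # indices of the dynamic input tensors
    (0,),                                              # matching dynamic dimension per tensor
    get_last_power_of_2_num_tokens_buckets(8192),      # candidate num_tokens buckets
    lambda x: min(last_positive_power_of_2(x), 8192),  # map a runtime size to a bucket
)
tuning_config = TuningConfig(dynamic_tensor_specs=(spec,))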
@@ -946,7 +946,7 @@ class MoERunner(TunableRunner):
                 (0, 0, 0, 0, 0, 0),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
-                # dynamic_tensor_initializers
+                dynamic_tensor_initializers,
             ),
         )
     )
@@ -957,7 +957,7 @@ class MoERunner(TunableRunner):
                 (0, 0, 0, 0, 0),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
-                # dynamic_tensor_initializers[:5]
+                dynamic_tensor_initializers[:5],
             ),
         ),
     )
@@ -1057,6 +1057,19 @@ def forward(
         intermediate_size: int,
         num_local_experts: int,
         num_tokens: int,
+        routing_bias: Optional[torch.Tensor] = None,
+        gemm1_bias: Optional[torch.Tensor] = None,
+        gemm1_alpha: Optional[torch.Tensor] = None,
+        gemm1_beta: Optional[torch.Tensor] = None,
+        gemm1_clamp_limit: Optional[torch.Tensor] = None,
+        gemm2_bias: Optional[torch.Tensor] = None,
+        output1_scale_scalar: Optional[torch.Tensor] = None,
+        output1_scale_gate_scalar: Optional[torch.Tensor] = None,
+        output2_scale_scalar: Optional[torch.Tensor] = None,
+        n_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        local_expert_offset: int = 0,
+        routed_scaling_factor: Optional[float] = None,
         routing_method_type: int = 1,
         tactic: int = -1,
         do_preparation: bool = False,
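Every new parameter defaults to None (or 0 for local_expert_offset), so callers that only pass the original arguments keep the old behaviour. As a purely hypothetical invocation (sizes and values are placeholders, and the leading `inputs` positional is assumed to match the tuner call later in this diff):

# Hypothetical usage; dtype_act / dtype_weights come from deduce_trtllm_gen_tensor_dtype(...).
runner = MoERunner(
    top_k=8,
    num_experts=128,
    dtype_act=dtype_act,
    dtype_weights=dtype_weights,
    use_deepseek_fp8=False,
    tile_tokens_dim=8,
    tune_max_num_tokens=8192,
)
runner.forward(
    inputs,                      # assumed first positional: the tensor list built by the op
    hidden_size=7168,
    intermediate_size=2048,
    num_local_experts=128,
    num_tokens=256,
    n_group=8,                   # new optional routing-group arguments
    topk_group=4,
    local_expert_offset=0,       # new, defaults to 0
    routed_scaling_factor=None,  # new, still optional
    routing_method_type=1,
)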
@@ -1098,34 +1111,34 @@ def forward(
             routing_logits.to(torch.bfloat16),
             topk_ids,
             expert_weights,
-            None,  # routing_bias
+            routing_bias,
             hidden_states,
-            hidden_states_scale.reshape(-1),  # hidden_states_scale
+            hidden_states_scale,  # hidden_states_scale
             gemm1_weights,
             gemm1_weights_scale,
-            None,  # gemm1_bias
-            None,  # gemm1_alpha
-            None,  # gemm1_beta
-            None,  # gemm1_clamp_limit
+            gemm1_bias,
+            gemm1_alpha,
+            gemm1_beta,
+            gemm1_clamp_limit,
             gemm2_weights,
             gemm2_weights_scale,
-            None,  # gemm2_bias
-            None,  # output1_scale_scalar
-            None,  # output1_scale_gate_scalar
-            None,  # output2_scale_scalar
+            gemm2_bias,
+            output1_scale_scalar,
+            output1_scale_gate_scalar,
+            output2_scale_scalar,
             num_local_experts,
             self.top_k,
-            None,  # n_group
-            None,  # topk_group
+            n_group,
+            topk_group,
             intermediate_size,
-            0,  # local_expert_offset
+            local_expert_offset,
             num_local_experts,
-            None,  # routed_scaling_factor
-            tile_tokens_dim,  # tile_tokens_dim
-            routing_method_type,  # routing_method_type
+            routed_scaling_factor,
+            tile_tokens_dim,
+            routing_method_type,
             True,  # do_finalize
-            output,  # output
-            tactic,  # config_idx
+            output,
+            tactic,
         )

     @classmethod
@@ -1138,7 +1151,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
                     (0, 0, 0, 0, 0, 0),
                     get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                     lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
-                    # cls.dynamic_tensor_initializers
+                    cls.dynamic_tensor_initializers,
                 ),
             )
         )
@@ -1149,7 +1162,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
                     (0, 0, 0, 0, 0),
                     get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                     lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
-                    # cls.dynamic_tensor_initializers[:5]
+                    cls.dynamic_tensor_initializers[:5],
                 ),
             ),
         )
@@ -1378,69 +1391,64 @@ def trtllm_fp4_block_scale_moe_op(
     )

     tuner = AutoTuner.get()
-    if tuner.is_tuning_mode:
-        MoERunner.refine_tuning_config(tune_max_num_tokens)
-        dtype_act = deduce_trtllm_gen_tensor_dtype(
-            hidden_states, hidden_states_scale
-        )
-        dtype_weights = deduce_trtllm_gen_tensor_dtype(
-            gemm1_weights, gemm1_weights_scale
-        )
-        moe_runner = MoERunner(
-            top_k=top_k,
-            num_experts=num_experts,
-            dtype_act=dtype_act,
-            dtype_weights=dtype_weights,
-            use_deepseek_fp8=False,
-            tile_tokens_dim=tile_tokens_dim,
-            tune_max_num_tokens=tune_max_num_tokens,
-        )
-        tunning_config = (
-            MoERunner.tuning_config_no_hidden_states_scales
-            if hidden_states_scale is None
-            else MoERunner.tuning_config_with_hidden_states_scales
-        )
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-            gemm1_weights,
-            gemm2_weights,
-        ]
-        # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
-        if hidden_states_scale is not None:
-            inputs.append(hidden_states_scale)
-        inputs.append(gemm1_weights_scale)
-        inputs.append(gemm2_weights_scale)
-
-        _, tactic = tuner.choose_one(
-            "flashinfer::trtllm_fp4_block_scale_moe",
-            [moe_runner],
-            tunning_config,
-            inputs,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_local_experts=num_experts,
-            num_tokens=num_tokens,
-            routing_method_type=routing_method_type,
-        )
-        print(f"tactic: {tactic}")
-        default_tactic = moe_op.trtllm_get_default_moe_configs(
-            tile_tokens_dim,
-            dtype_act,
-            dtype_weights,
-            False,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            num_experts,
-            num_tokens,
-        )
-        print(f"default_tactic: {default_tactic}")
-    else:
-        tactic = -1
+    MoERunner.refine_tuning_config(tune_max_num_tokens)
+    dtype_act = deduce_trtllm_gen_tensor_dtype(hidden_states, hidden_states_scale)
+    dtype_weights = deduce_trtllm_gen_tensor_dtype(
+        gemm1_weights, gemm1_weights_scale
+    )
+    moe_runner = MoERunner(
+        top_k=top_k,
+        num_experts=num_experts,
+        dtype_act=dtype_act,
+        dtype_weights=dtype_weights,
+        use_deepseek_fp8=False,
+        tile_tokens_dim=tile_tokens_dim,
+        tune_max_num_tokens=tune_max_num_tokens,
+    )
+    tunning_config = (
+        MoERunner.tuning_config_no_hidden_states_scales
+        if hidden_states_scale is None
+        else MoERunner.tuning_config_with_hidden_states_scales
+    )
+    inputs = [
+        output,
+        routing_logits,
+        topk_ids,
+        expert_weights,
+        hidden_states,
+        gemm1_weights,
+        gemm2_weights,
+    ]
+    # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
+    if hidden_states_scale is not None:
+        inputs.append(hidden_states_scale)
+    inputs.append(gemm1_weights_scale)
+    inputs.append(gemm2_weights_scale)
+
+    _, tactic = tuner.choose_one(
+        "flashinfer::trtllm_fp4_block_scale_moe",
+        [moe_runner],
+        tunning_config,
+        inputs,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_local_experts=num_experts,
+        num_tokens=num_tokens,
+        routing_bias=routing_bias,
+        gemm1_bias=gemm1_bias,
+        gemm1_alpha=gemm1_alpha,
+        gemm1_beta=gemm1_beta,
+        gemm1_clamp_limit=gemm1_clamp_limit,
+        gemm2_bias=gemm2_bias,
+        output1_scale_scalar=output1_scale_scalar,
+        output1_scale_gate_scalar=output1_scale_gate_scalar,
+        output2_scale_scalar=output2_scale_scalar,
+        n_group=n_group,
+        topk_group=topk_group,
+        local_expert_offset=local_expert_offset,
+        routed_scaling_factor=routed_scaling_factor,
+        routing_method_type=routing_method_type,
+    )

     # Call the C++ function for block scale MoE
     output = moe_op.trtllm_fp4_block_scale_moe(
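With the `if tuner.is_tuning_mode` guard and the debug prints removed, the tuner path always runs, and every optional kernel argument is forwarded to choose_one as a keyword. A loose sketch of why those kwargs matter, under the assumption that the tuner replays them into MoERunner.forward when timing each candidate tactic (names and structure below are illustrative, not the flashinfer AutoTuner implementation):

import time

def choose_one_sketch(runner, inputs, tactics, **kwargs):
    # Hedged sketch: time each candidate tactic with the same keyword arguments
    # the real forward receives, and keep the fastest one. Any optional kernel
    # input left out of the signature would silently fall back to its default
    # during profiling, which is why the diff threads them all through.
    best_tactic, best_time = -1, float("inf")
    for tactic in tactics:
        start = time.perf_counter()
        runner.forward(inputs, tactic=tactic, **kwargs)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    return best_tactic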
@@ -1449,7 +1457,7 @@ def trtllm_fp4_block_scale_moe_op(
         expert_weights,
         routing_bias,
         hidden_states,
-        hidden_states_scale.reshape(-1),
+        hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm1_bias,