
Commit c4270bb

Commit message: upd

Parent: 82bca24

File tree: 2 files changed (+61, -59 lines)


flashinfer/fused_moe/core.py (59 additions, 58 deletions)
@@ -921,28 +921,28 @@ class MoERunner(TunableRunner):
     dynamic_tensor_initializers = [
         lambda shapes, dtype, device: torch.empty(
             shapes, device=device, dtype=dtype
-        ),  # output buffer
+        ),  # output buffer, [num_tokens, hidden_size]
         lambda shapes, dtype, device: torch.rand(
             shapes, device=device, dtype=dtype
-        ),  # routing_logits
+        ),  # routing_logits, [num_tokens, num_experts]
         lambda shapes, dtype, device: torch.empty(
             shapes, device=device, dtype=dtype
-        ),  # topk_ids buffer. empty since routing_logits is used
+        ),  # topk_ids buffer. empty since routing_logits is used. [num_tokens, topk]
         lambda shapes, dtype, device: torch.empty(
             shapes, device=device, dtype=dtype
-        ),  # expert_weights buffer. empty since routing_logits is used
+        ),  # expert_weights buffer. empty since routing_logits is used. [num_tokens, topk]
         lambda shapes, dtype, device: torch.randn(shapes, device=device).to(
             dtype
-        ),  # hidden_states
+        ),  # hidden_states, [num_tokens, hidden_size]
         lambda shapes, dtype, device: torch.ones(shapes, device=device).to(
             dtype
-        ),  # hidden_states_scale
+        ),  # hidden_states_scale, [num_tokens, hidden_size // sf_vec_size]
     ]
     # their first dimension is num_tokens which will be tuned
     tuning_config_with_hidden_states_scales = TuningConfig(
         dynamic_tensor_specs=(
             DynamicTensorSpec(
-                (0, 1, 2, 3, 4, 7),
+                (0, 1, 2, 3, 4, 5),
                 (0, 0, 0, 0, 0, 0),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
@@ -972,6 +972,8 @@ def __init__(
         dtype_act: DtypeTrtllmGen,
         dtype_weights: DtypeTrtllmGen,
         use_deepseek_fp8: bool,
+        hidden_size: int,
+        intermediate_size: int,
         tile_tokens_dim: Optional[int] = None,
         tune_max_num_tokens: int = 8192,
     ):
@@ -981,6 +983,8 @@ def __init__(
         self.dtype_weights = dtype_weights
         self.use_deepseek_fp8 = use_deepseek_fp8
         self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
         self.tile_tokens_dim = tile_tokens_dim

     def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
@@ -1016,17 +1020,8 @@ def get_valid_tactics(
             topk_ids,
             expert_weights,
             hidden_states,
-            gemm1_weights,
-            gemm2_weights,
             *extra_inputs,
         ) = inputs
-        hidden_size = hidden_states.shape[1]
-        if (
-            self.dtype_act == DtypeTrtllmGen.E2m1
-            or self.dtype_act == DtypeTrtllmGen.MxE2m1
-        ):  # packed into uint8
-            hidden_size *= 2
-        intermediate_size = gemm1_weights.shape[1] // 2
         num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
             self.get_tile_tokens_dim(num_tokens, self.top_k)
@@ -1039,8 +1034,8 @@ def get_valid_tactics(
             self.dtype_weights,
             self.use_deepseek_fp8,
             self.top_k,
-            hidden_size,
-            intermediate_size,
+            self.hidden_size,
+            self.intermediate_size,
             self.num_experts,
             num_tokens,
         )
@@ -1053,24 +1048,25 @@ def get_valid_tactics(
     def forward(
         self,
         inputs: List[torch.Tensor],
-        hidden_size: int,
-        intermediate_size: int,
         num_local_experts: int,
-        num_tokens: int,
-        routing_bias: Optional[torch.Tensor] = None,
-        gemm1_bias: Optional[torch.Tensor] = None,
-        gemm1_alpha: Optional[torch.Tensor] = None,
-        gemm1_beta: Optional[torch.Tensor] = None,
-        gemm1_clamp_limit: Optional[torch.Tensor] = None,
-        gemm2_bias: Optional[torch.Tensor] = None,
-        output1_scale_scalar: Optional[torch.Tensor] = None,
-        output1_scale_gate_scalar: Optional[torch.Tensor] = None,
-        output2_scale_scalar: Optional[torch.Tensor] = None,
-        n_group: Optional[int] = None,
-        topk_group: Optional[int] = None,
-        local_expert_offset: int = 0,
-        routed_scaling_factor: Optional[float] = None,
-        routing_method_type: int = 1,
+        routing_bias: Optional[torch.Tensor],
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: Optional[torch.Tensor],
+        gemm1_bias: Optional[torch.Tensor],
+        gemm1_alpha: Optional[torch.Tensor],
+        gemm1_beta: Optional[torch.Tensor],
+        gemm1_clamp_limit: Optional[torch.Tensor],
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: Optional[torch.Tensor],
+        gemm2_bias: Optional[torch.Tensor],
+        output1_scale_scalar: Optional[torch.Tensor],
+        output1_scale_gate_scalar: Optional[torch.Tensor],
+        output2_scale_scalar: Optional[torch.Tensor],
+        n_group: Optional[int],
+        topk_group: Optional[int],
+        local_expert_offset: int,
+        routed_scaling_factor: Optional[float],
+        routing_method_type: int,
         tactic: int = -1,
         do_preparation: bool = False,
     ):
@@ -1080,10 +1076,9 @@ def forward(
             topk_ids,
             expert_weights,
             hidden_states,
-            gemm1_weights,
-            gemm2_weights,
             *extra_inputs,
         ) = inputs
+        num_tokens = routing_logits.shape[0]
         tile_tokens_dim = (
             self.get_tile_tokens_dim(num_tokens, self.top_k)
             if self.tile_tokens_dim is None
@@ -1092,19 +1087,27 @@ def forward(

         extra_input_idx = 0
         if trtllm_gen_dtype_has_scale(self.dtype_act):
-            hidden_states_scale = (
-                extra_inputs[extra_input_idx].view(torch.float8_e4m3fn).reshape(-1)
-            )
+            hidden_states_scale = extra_inputs[extra_input_idx]
             extra_input_idx += 1
         else:
             hidden_states_scale = None
-        if trtllm_gen_dtype_has_scale(self.dtype_weights):
-            gemm1_weights_scale = extra_inputs[extra_input_idx]
-            gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
-            extra_input_idx += 2
-        else:
-            gemm1_weights_scale = None
-            gemm2_weights_scale = None
+        # sanity checks to ensure that dynamic tensors have the correct shapes
+        assert output.shape[0] == num_tokens, (
+            "output's first dimension must be batch size."
+        )
+        assert topk_ids.shape[0] == num_tokens, (
+            "topk_ids's first dimension must be batch size."
+        )
+        assert expert_weights.shape[0] == num_tokens, (
+            "expert_weights's first dimension must be batch size."
+        )
+        assert hidden_states.shape[0] == num_tokens, (
+            "hidden_states's first dimension must be batch size."
+        )
+        assert (
+            hidden_states_scale is None
+            or hidden_states_scale.shape[0] == num_tokens
+        ), "hidden_states_scale's first dimension must be batch size"

         # TODO(siyuan): support fp8
         moe_op.trtllm_fp4_block_scale_moe(
@@ -1126,11 +1129,11 @@ def forward(
             output1_scale_scalar,
             output1_scale_gate_scalar,
             output2_scale_scalar,
-            num_local_experts,
+            self.num_experts,
             self.top_k,
             n_group,
             topk_group,
-            intermediate_size,
+            self.intermediate_size,
             local_expert_offset,
             num_local_experts,
             routed_scaling_factor,
@@ -1147,7 +1150,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
         cls.tuning_config_with_hidden_states_scales = TuningConfig(
             dynamic_tensor_specs=(
                 DynamicTensorSpec(
-                    (0, 1, 2, 3, 4, 7),
+                    (0, 1, 2, 3, 4, 5),
                     (0, 0, 0, 0, 0, 0),
                     get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                     lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
@@ -1402,6 +1405,8 @@ def trtllm_fp4_block_scale_moe_op(
         dtype_act=dtype_act,
         dtype_weights=dtype_weights,
         use_deepseek_fp8=False,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
         tile_tokens_dim=tile_tokens_dim,
         tune_max_num_tokens=tune_max_num_tokens,
     )
@@ -1416,29 +1421,25 @@ def trtllm_fp4_block_scale_moe_op(
         topk_ids,
         expert_weights,
         hidden_states,
-        gemm1_weights,
-        gemm2_weights,
     ]
-    # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
    if hidden_states_scale is not None:
         inputs.append(hidden_states_scale)
-    inputs.append(gemm1_weights_scale)
-    inputs.append(gemm2_weights_scale)

     _, tactic = tuner.choose_one(
         "flashinfer::trtllm_fp4_block_scale_moe",
         [moe_runner],
         tunning_config,
         inputs,
-        hidden_size=hidden_size,
-        intermediate_size=intermediate_size,
         num_local_experts=num_experts,
-        num_tokens=num_tokens,
         routing_bias=routing_bias,
+        gemm1_weights=gemm1_weights,
+        gemm1_weights_scale=gemm1_weights_scale,
         gemm1_bias=gemm1_bias,
         gemm1_alpha=gemm1_alpha,
         gemm1_beta=gemm1_beta,
         gemm1_clamp_limit=gemm1_clamp_limit,
+        gemm2_weights=gemm2_weights,
+        gemm2_weights_scale=gemm2_weights_scale,
         gemm2_bias=gemm2_bias,
         output1_scale_scalar=output1_scale_scalar,
         output1_scale_gate_scalar=output1_scale_gate_scalar,
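
Note on the DynamicTensorSpec change: the first tuple lists the positions, within the tuner's dynamic input list, of the tensors whose first dimension (num_tokens) is tuned. With gemm1_weights and gemm2_weights (and their scales) no longer passed through that list, hidden_states_scale shifts from position 7 to position 5, hence (0, 1, 2, 3, 4, 7) becoming (0, 1, 2, 3, 4, 5). A plain-Python sketch of the two orderings, using placeholder names rather than the real tensors built by dynamic_tensor_initializers:

# Positions of the autotuner's dynamic inputs before and after this commit.
# Placeholder strings only; the real entries are torch tensors.
old_inputs = [
    "output",               # 0
    "routing_logits",       # 1
    "topk_ids",             # 2
    "expert_weights",       # 3
    "hidden_states",        # 4
    "gemm1_weights",        # 5  (dropped from the dynamic list by this commit)
    "gemm2_weights",        # 6  (dropped from the dynamic list by this commit)
    "hidden_states_scale",  # 7  -> old spec tuned index 7
]
new_inputs = [
    "output",               # 0
    "routing_logits",       # 1
    "topk_ids",             # 2
    "expert_weights",       # 3
    "hidden_states",        # 4
    "hidden_states_scale",  # 5  -> new spec tunes index 5
]
# The tuned dimension is the first (num_tokens) dimension of each listed
# tensor, matching the tuple of zeros (0, 0, 0, 0, 0, 0) in the spec.
assert old_inputs.index("hidden_states_scale") == 7
assert new_inputs.index("hidden_states_scale") == 5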

include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh (2 additions, 1 deletion)
@@ -51,7 +51,8 @@ struct TopKRedType {
   static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) {
     auto valueBits = cub::Traits<TypeExpW>::TwiddleIn(
         reinterpret_cast<typename cub::Traits<TypeExpW>::UnsignedBits&>(val));
-    TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
+    TypeCmp compactTmp;
+    memcpy(&compactTmp, &valueBits, sizeof(valueBits));
     compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
     // Use 65535 minus idx to give higher priority to elements with smaller indices.
     return compactTmp;
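
Why the memcpy: reading valueBits through reinterpret_cast<TypeCmp&> accesses an object via a reference to an unrelated integer type, which violates C++ aliasing rules (and can read past the source object when TypeCmp is wider than the value bits); copying sizeof(valueBits) bytes with memcpy is the well-defined way to reinterpret the bit pattern. A minimal host-only sketch of the same pattern, using stand-in types rather than the kernel's TypeCmp/TypeExpW and assuming a little-endian target:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-ins for the kernel's types: 16-bit value bits packed into a wider
// 64-bit comparison key, mirroring the structure of makeCmpVal above.
using ValueBits = uint16_t;
using CmpKey = uint64_t;

CmpKey makeCmpVal(ValueBits valueBits, int32_t idx, int moveBits, int32_t maxIdx) {
  CmpKey compact = 0;  // zero-init so the bytes memcpy does not write are defined
  // Well-defined bit reinterpretation: copy the bytes instead of casting a
  // reference to an unrelated type (fills the low bytes on little-endian).
  std::memcpy(&compact, &valueBits, sizeof(valueBits));
  // Pack the value into the high bits and (maxIdx - idx) into the low bits so
  // that, for equal values, the smaller index yields the larger key.
  compact = (compact << moveBits) | (0xFFFF & (maxIdx - idx));
  return compact;
}

int main() {
  CmpKey a = makeCmpVal(0x1234, /*idx=*/3, /*moveBits=*/16, /*maxIdx=*/65535);
  CmpKey b = makeCmpVal(0x1234, /*idx=*/7, /*moveBits=*/16, /*maxIdx=*/65535);
  // Same value, smaller index wins the comparison: prints a > b: 1.
  std::printf("a=0x%llx b=0x%llx (a > b: %d)\n",
              (unsigned long long)a, (unsigned long long)b, (int)(a > b));
  return 0;
}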
