
Commit 82bca24

committed: UPD
1 parent f590bf2 commit 82bca24

File tree: 4 files changed (+152 / -118 lines)


csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 5 additions & 6 deletions
@@ -876,11 +876,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
     TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                 "hidden_states_scale must be fp8.");
 
-    TORCH_CHECK(hidden_states_scale.value().dim() == 1, "hidden_states_scale must be 1D.");
-    TORCH_CHECK(hidden_states_scale.value().sizes()[0] ==
-                    tensorrt_llm::computeFP4LinearLayoutSFSize(args.num_tokens,
-                                                               args.hidden_size / sf_vec_size),
-                "hidden_states_scale has incorrect size");
+    TORCH_CHECK(
+        hidden_states_scale.value().numel() == tensorrt_llm::computeFP4LinearLayoutSFSize(
+                                                   args.num_tokens, args.hidden_size / sf_vec_size),
+        "hidden_states_scale has incorrect size");
   }
 
   TORCH_CHECK(gemm1_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2,
@@ -1079,7 +1078,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
   auto mDtypeAct = btg::Dtype::Bfloat16;
   if (hidden_states.scalar_type() == torch_ext::FLOAT4_E2M1X2) {
     TORCH_CHECK(hidden_states_scale.has_value() &&
-                    hidden_states_scale.value().scalar_type() == at::ScalarType::Byte,
+                    hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                 "hidden_states_scale must be provided for fp4 activation.");
     if (hidden_states_scale_vec_size == 16) {
       mDtypeAct = btg::Dtype::E2m1;
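
Note on this launcher change: the old checks required hidden_states_scale to be a flattened 1D tensor of an exact length, and the fp4 path expected a Byte (uint8) scale tensor. The new checks compare only the element count against tensorrt_llm::computeFP4LinearLayoutSFSize and expect Float8_e4m3fn, which appears to be what allows the Python callers below to drop their .reshape(-1) calls. A minimal Python sketch of the relaxed contract, assuming for illustration that the scale-factor count equals num_tokens * (hidden_size // sf_vec_size) (the exact padding rules of computeFP4LinearLayoutSFSize are not shown in this diff):

import torch

# Hypothetical sizes, for illustration only.
num_tokens, hidden_size, sf_vec_size = 4, 1024, 16
expected_numel = num_tokens * (hidden_size // sf_vec_size)  # assumed; ignores any padding

# A 2D scale tensor now passes: only numel() is checked, not dim().
scale_2d = torch.empty(
    num_tokens, hidden_size // sf_vec_size, dtype=torch.float8_e4m3fn
)
assert scale_2d.numel() == expected_numel  # mirrors the new TORCH_CHECK
# The removed checks would additionally have required scale_2d.dim() == 1.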

flashinfer/autotuner.py

Lines changed: 45 additions & 11 deletions
@@ -10,7 +10,6 @@
 from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional
 
 import torch
-from tqdm import tqdm
 
 # from tensorrt_llm.bindings.internal.runtime import delay_kernel
 # from tensorrt_llm.logger import logger
@@ -38,7 +37,7 @@ def get_config_path(is_module: bool):
     )
 
 
-@dataclass(slots=True, unsafe_hash=True)
+@dataclass(slots=True)
 class DynamicTensorSpec:
     """
     A specification for a dynamic tensor dimension.
@@ -52,11 +51,37 @@ class DynamicTensorSpec:
         tensor_initializers: A list of functions to initialize the tensors.
     """
 
-    input_idx: Tuple[int]
-    dim_idx: Tuple[int]
-    gen_tuning_buckets: Union[Tuple[int], Callable]
+    input_idx: Tuple[int, ...]
+    dim_idx: Tuple[int, ...]
+    gen_tuning_buckets: Union[Tuple[int, ...], Callable]
     map_to_tuning_buckets: Callable
-    # tensor_initializers: Tuple[Callable] = field(default_factory=lambda: [lambda shapes, dtype, device: torch.randn(shapes, device=device, dtype=dtype)])
+    tensor_initializers: List[Callable] = field(
+        default_factory=lambda: None
+    )
+
+    def __post_init__(self):
+        # Set default tensor_initializers if not provided
+        if self.tensor_initializers is None:
+            self.tensor_initializers = [
+                lambda shapes, dtype, device: torch.randn(
+                    shapes, device=device, dtype=dtype
+                )
+                for _ in range(len(self.input_idx))
+            ]
+
+    def __hash__(self) -> int:
+        # FIXME: currently not hashing tensor_initializers
+        return hash(
+            (
+                self.input_idx,
+                self.dim_idx,
+                # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id
+                self.gen_tuning_buckets
+                if isinstance(self.gen_tuning_buckets, tuple)
+                else id(self.gen_tuning_buckets),
+                id(self.map_to_tuning_buckets),
+            )
+        )
 
 
 @dataclass(slots=True, unsafe_hash=True)
@dataclass(slots=True, unsafe_hash=True)
@@ -146,6 +171,7 @@ class OptimizationProfile:
     """Ranges of all tensors, all dimension"""
 
     shapes: List[List[Dim]]
+    tensor_initializers: List[Optional[Callable]]
 
     def get_hash_key(self):
         return self.get_opt_shapes()
@@ -585,7 +611,8 @@ def _generate_optimization_profiles(
                     else [StaticDim(0)]
                 )
                 for t in inputs
-            ]
+            ],
+            [None] * len(inputs),
         )
 
         generated_profiles: List[OptimizationProfile] = []
@@ -599,8 +626,14 @@ def _generate_optimization_profiles(
                 "The given dynamic dimension must provide a opt value generation function or a list of opt values"
             )
             assert len(spec.input_idx) == len(spec.dim_idx), (
-                "The number of input indices and dimension indices must be the same"
+                f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}"
             )
+            assert len(spec.tensor_initializers) == len(spec.input_idx), (
+                f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}"
+            )
+            for i, idx in enumerate(spec.input_idx):
+                base_profile.tensor_initializers[idx] = spec.tensor_initializers[i]
+
             if inspect.isfunction(spec.gen_tuning_buckets):
                 opt_shapes = spec.gen_tuning_buckets(
                     base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt()
@@ -720,15 +753,16 @@ def _create_tensor_like(
     def _prepare_input_tensors(
         self, profile: OptimizationProfile, inputs: List[torch.Tensor]
     ) -> List[torch.Tensor]:
+        default_initializer = lambda shapes, dtype, device: torch.rand(
+            shapes, device=device
+        ).to(dtype)
         tensors = []
         for i, p in enumerate(profile.shapes):
             if any(isinstance(d, DynamicDim) for d in p):
                 tensor = self._create_tensor_like(
                     inputs[i],
                     p,
-                    lambda shapes, dtype, device: torch.rand(shapes, device=device).to(
-                        dtype
-                    ),
+                    profile.tensor_initializers[i] or default_initializer,
                 )
             else:
                 tensor = inputs[i]
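
Note on the autotuner changes: DynamicTensorSpec now carries an optional per-input list of tensor initializers, _generate_optimization_profiles copies them into the OptimizationProfile, and _prepare_input_tensors uses them in place of the random-fill default when it materializes tuning inputs. A minimal usage sketch based only on the DynamicTensorSpec signature added above (the bucket values and mapping below are illustrative, not from this commit):

import torch
from flashinfer.autotuner import DynamicTensorSpec  # module path as in this diff

# One initializer per entry of input_idx; each receives (shapes, dtype, device).
zero_init = lambda shapes, dtype, device: torch.zeros(shapes, device=device, dtype=dtype)
rand_init = lambda shapes, dtype, device: torch.rand(shapes, device=device).to(dtype)

spec = DynamicTensorSpec(
    input_idx=(0, 1),                   # inputs with a dynamic dimension
    dim_idx=(0, 0),                     # which dimension of each input is dynamic
    gen_tuning_buckets=(1, 2, 4, 8),    # illustrative bucket list
    map_to_tuning_buckets=lambda x: x,  # illustrative mapping
    tensor_initializers=[zero_init, rand_init],
)
# Omitting tensor_initializers keeps the torch.randn default set in __post_init__.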

flashinfer/fused_moe/core.py

Lines changed: 96 additions & 88 deletions
@@ -372,8 +372,8 @@ class MoERunner(TunableRunner):
     tuning_config = TuningConfig(
         dynamic_tensor_specs=(
             DynamicTensorSpec(
-                0,
-                0,
+                (0,),
+                (0,),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
             ),
@@ -946,7 +946,7 @@ class MoERunner(TunableRunner):
                 (0, 0, 0, 0, 0, 0),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
-                # dynamic_tensor_initializers
+                dynamic_tensor_initializers,
             ),
         )
     )
@@ -957,7 +957,7 @@ class MoERunner(TunableRunner):
                 (0, 0, 0, 0, 0),
                 get_last_power_of_2_num_tokens_buckets(8192),
                 lambda x: min(last_positive_power_of_2(x), 8192),
-                # dynamic_tensor_initializers[:5]
+                dynamic_tensor_initializers[:5],
             ),
         ),
     )
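
Note: dynamic_tensor_initializers is now passed instead of being commented out, but its definition lies outside the hunks shown here. Given the contract added in flashinfer/autotuner.py (one callable per entry of input_idx, each taking shapes, dtype, and device), a hypothetical definition compatible with the six-element spec and the [:5] slice above could look like this (illustration only, not the list actually defined in core.py):

import torch

dynamic_tensor_initializers = [
    lambda shapes, dtype, device: torch.empty(shapes, device=device, dtype=dtype)
    for _ in range(6)  # one initializer per dynamic input in the spec above
]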
@@ -1057,6 +1057,19 @@ def forward(
         intermediate_size: int,
         num_local_experts: int,
         num_tokens: int,
+        routing_bias: Optional[torch.Tensor] = None,
+        gemm1_bias: Optional[torch.Tensor] = None,
+        gemm1_alpha: Optional[torch.Tensor] = None,
+        gemm1_beta: Optional[torch.Tensor] = None,
+        gemm1_clamp_limit: Optional[torch.Tensor] = None,
+        gemm2_bias: Optional[torch.Tensor] = None,
+        output1_scale_scalar: Optional[torch.Tensor] = None,
+        output1_scale_gate_scalar: Optional[torch.Tensor] = None,
+        output2_scale_scalar: Optional[torch.Tensor] = None,
+        n_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        local_expert_offset: int = 0,
+        routed_scaling_factor: Optional[float] = None,
         routing_method_type: int = 1,
         tactic: int = -1,
         do_preparation: bool = False,
@@ -1098,34 +1111,34 @@ def forward(
             routing_logits.to(torch.bfloat16),
             topk_ids,
             expert_weights,
-            None,  # routing_bias
+            routing_bias,
             hidden_states,
-            hidden_states_scale.reshape(-1),  # hidden_states_scale
+            hidden_states_scale,  # hidden_states_scale
             gemm1_weights,
             gemm1_weights_scale,
-            None,  # gemm1_bias
-            None,  # gemm1_alpha
-            None,  # gemm1_beta
-            None,  # gemm1_clamp_limit
+            gemm1_bias,
+            gemm1_alpha,
+            gemm1_beta,
+            gemm1_clamp_limit,
             gemm2_weights,
             gemm2_weights_scale,
-            None,  # gemm2_bias
-            None,  # output1_scale_scalar
-            None,  # output1_scale_gate_scalar
-            None,  # output2_scale_scalar
+            gemm2_bias,
+            output1_scale_scalar,
+            output1_scale_gate_scalar,
+            output2_scale_scalar,
             num_local_experts,
             self.top_k,
-            None,  # n_group
-            None,  # topk_group
+            n_group,
+            topk_group,
             intermediate_size,
-            0,  # local_expert_offset
+            local_expert_offset,
             num_local_experts,
-            None,  # routed_scaling_factor
-            tile_tokens_dim,  # tile_tokens_dim
-            routing_method_type,  # routing_method_type
+            routed_scaling_factor,
+            tile_tokens_dim,
+            routing_method_type,
             True,  # do_finalize
-            output,  # output
-            tactic,  # config_idx
+            output,
+            tactic,
         )
 
     @classmethod
@@ -1138,7 +1151,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
                     (0, 0, 0, 0, 0, 0),
                     get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                     lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
-                    # cls.dynamic_tensor_initializers
+                    cls.dynamic_tensor_initializers,
                 ),
             )
         )
@@ -1149,7 +1162,7 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
                     (0, 0, 0, 0, 0),
                     get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
                     lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
-                    # cls.dynamic_tensor_initializers[:5]
+                    cls.dynamic_tensor_initializers[:5],
                 ),
             ),
         )
@@ -1378,69 +1391,64 @@ def trtllm_fp4_block_scale_moe_op(
     )
 
     tuner = AutoTuner.get()
-    if tuner.is_tuning_mode:
-        MoERunner.refine_tuning_config(tune_max_num_tokens)
-        dtype_act = deduce_trtllm_gen_tensor_dtype(
-            hidden_states, hidden_states_scale
-        )
-        dtype_weights = deduce_trtllm_gen_tensor_dtype(
-            gemm1_weights, gemm1_weights_scale
-        )
-        moe_runner = MoERunner(
-            top_k=top_k,
-            num_experts=num_experts,
-            dtype_act=dtype_act,
-            dtype_weights=dtype_weights,
-            use_deepseek_fp8=False,
-            tile_tokens_dim=tile_tokens_dim,
-            tune_max_num_tokens=tune_max_num_tokens,
-        )
-        tunning_config = (
-            MoERunner.tuning_config_no_hidden_states_scales
-            if hidden_states_scale is None
-            else MoERunner.tuning_config_with_hidden_states_scales
-        )
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-            gemm1_weights,
-            gemm2_weights,
-        ]
-        # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
-        if hidden_states_scale is not None:
-            inputs.append(hidden_states_scale)
-        inputs.append(gemm1_weights_scale)
-        inputs.append(gemm2_weights_scale)
-
-        _, tactic = tuner.choose_one(
-            "flashinfer::trtllm_fp4_block_scale_moe",
-            [moe_runner],
-            tunning_config,
-            inputs,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_local_experts=num_experts,
-            num_tokens=num_tokens,
-            routing_method_type=routing_method_type,
-        )
-        print(f"tactic: {tactic}")
-        default_tactic = moe_op.trtllm_get_default_moe_configs(
-            tile_tokens_dim,
-            dtype_act,
-            dtype_weights,
-            False,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            num_experts,
-            num_tokens,
-        )
-        print(f"default_tactic: {default_tactic}")
-    else:
-        tactic = -1
+    MoERunner.refine_tuning_config(tune_max_num_tokens)
+    dtype_act = deduce_trtllm_gen_tensor_dtype(hidden_states, hidden_states_scale)
+    dtype_weights = deduce_trtllm_gen_tensor_dtype(
+        gemm1_weights, gemm1_weights_scale
+    )
+    moe_runner = MoERunner(
+        top_k=top_k,
+        num_experts=num_experts,
+        dtype_act=dtype_act,
+        dtype_weights=dtype_weights,
+        use_deepseek_fp8=False,
+        tile_tokens_dim=tile_tokens_dim,
+        tune_max_num_tokens=tune_max_num_tokens,
+    )
+    tunning_config = (
+        MoERunner.tuning_config_no_hidden_states_scales
+        if hidden_states_scale is None
+        else MoERunner.tuning_config_with_hidden_states_scales
+    )
+    inputs = [
+        output,
+        routing_logits,
+        topk_ids,
+        expert_weights,
+        hidden_states,
+        gemm1_weights,
+        gemm2_weights,
+    ]
+    # hidden_states_scale should be in front of gemm1_weights_scale and gemm2_weights_scale
+    if hidden_states_scale is not None:
+        inputs.append(hidden_states_scale)
+    inputs.append(gemm1_weights_scale)
+    inputs.append(gemm2_weights_scale)
+
+    _, tactic = tuner.choose_one(
+        "flashinfer::trtllm_fp4_block_scale_moe",
+        [moe_runner],
+        tunning_config,
+        inputs,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_local_experts=num_experts,
+        num_tokens=num_tokens,
+        routing_bias=routing_bias,
+        gemm1_bias=gemm1_bias,
+        gemm1_alpha=gemm1_alpha,
+        gemm1_beta=gemm1_beta,
+        gemm1_clamp_limit=gemm1_clamp_limit,
+        gemm2_bias=gemm2_bias,
+        output1_scale_scalar=output1_scale_scalar,
+        output1_scale_gate_scalar=output1_scale_gate_scalar,
+        output2_scale_scalar=output2_scale_scalar,
+        n_group=n_group,
+        topk_group=topk_group,
+        local_expert_offset=local_expert_offset,
+        routed_scaling_factor=routed_scaling_factor,
+        routing_method_type=routing_method_type,
+    )
 
     # Call the C++ function for block scale MoE
     output = moe_op.trtllm_fp4_block_scale_moe(
@@ -1449,7 +1457,7 @@ def trtllm_fp4_block_scale_moe_op(
         expert_weights,
         routing_bias,
         hidden_states,
-        hidden_states_scale.reshape(-1),
+        hidden_states_scale,
         gemm1_weights,
         gemm1_weights_scale,
         gemm1_bias,
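
Note: the `if tuner.is_tuning_mode:` gate and the two debug print() calls were removed, so the op now always builds a MoERunner and asks AutoTuner.choose_one for a tactic, and every newly plumbed optional tensor is passed to choose_one as a keyword argument. The sketch below shows the forwarding pattern this relies on; it is an assumption about the AutoTuner/TunableRunner contract, not code from this repository, and profile_tactics_sketch with its calling convention is hypothetical:

import time
import torch

def profile_tactics_sketch(runner, inputs, tactics, **kwargs):
    # Assumed contract: kwargs given to choose_one (routing_bias, gemm1_bias, ...)
    # reach the runner unchanged, so the profiled call matches the real call.
    timings = {}
    for tactic in tactics:
        torch.cuda.synchronize()
        start = time.perf_counter()
        runner(inputs, tactic=tactic, **kwargs)  # assumed runner calling convention
        torch.cuda.synchronize()
        timings[tactic] = time.perf_counter() - start
    return min(timings, key=timings.get)  # fastest tactic wins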
