
Commit e533cac

add trtllm-gen fp4 moe autotuner
1 parent 66144d2 commit e533cac

File tree

6 files changed: +556 additions, -58 deletions

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 89 additions & 9 deletions
@@ -881,11 +881,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
     TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                 "hidden_states_scale must be fp8.");

-    TORCH_CHECK(hidden_states_scale.value().dim() == 1, "hidden_states_scale must be 1D.");
-    TORCH_CHECK(hidden_states_scale.value().sizes()[0] ==
-                    tensorrt_llm::computeFP4LinearLayoutSFSize(args.num_tokens,
-                                                               args.hidden_size / sf_vec_size),
-                "hidden_states_scale has incorrect size");
+    TORCH_CHECK(
+        hidden_states_scale.value().numel() == tensorrt_llm::computeFP4LinearLayoutSFSize(
+                                                   args.num_tokens, args.hidden_size / sf_vec_size),
+        "hidden_states_scale has incorrect size");
   }

   TORCH_CHECK(gemm1_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2,
@@ -1059,7 +1058,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
     std::optional<int64_t> n_group, std::optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts,
     std::optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
-    int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output) {
+    int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output, int64_t config_index) {
   using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;

   int const num_tokens = hidden_states.sizes()[0];
@@ -1112,8 +1111,10 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
       tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);

-  auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex(
-      top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
+  if (config_index == -1) {
+    config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                       local_num_experts, num_tokens);
+  }

   return trtllm_fp4_block_scale_moe_launcher(
       routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
@@ -1122,7 +1123,84 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
       output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
       intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor,
       tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights,
-      moeConfigIndex, enable_pdl, output);
+      config_index, enable_pdl, output);
+}
+
+inline btg::Dtype get_dtype(int64_t const dtype) {
+  switch (dtype) {
+    case 0:
+      return btg::Dtype::Bfloat16;
+    case 1:
+      return btg::Dtype::Bool;
+    case 2:
+      return btg::Dtype::E2m1;
+    case 3:
+      return btg::Dtype::E2m3;
+    case 4:
+      return btg::Dtype::E3m2;
+    case 5:
+      return btg::Dtype::E4m3;
+    case 6:
+      return btg::Dtype::E5m2;
+    case 7:
+      return btg::Dtype::Fp16;
+    case 8:
+      return btg::Dtype::Fp32;
+    case 9:
+      return btg::Dtype::Int8;
+    case 10:
+      return btg::Dtype::Int32;
+    case 11:
+      return btg::Dtype::Int64;
+    case 12:
+      return btg::Dtype::MxE2m1;
+    case 13:
+      return btg::Dtype::MxE4m3;
+    case 14:
+      return btg::Dtype::UE8m0;
+    case 15:
+      return btg::Dtype::UInt8;
+    case 16:
+      return btg::Dtype::UInt16;
+    case 17:
+      return btg::Dtype::UInt32;
+    case 18:
+      return btg::Dtype::UInt64;
+    case 19:
+      return btg::Dtype::UInt128;
+    case 20:
+      return btg::Dtype::Void;
+    default:
+      TORCH_CHECK(false, "Invalid trtllm-gen dtype");
+  }
+  return btg::Dtype::E2m1;
+}
+
+int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
+                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
+                                       int64_t const top_k, int64_t const hidden_size,
+                                       int64_t const intermediate_size,
+                                       int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = get_dtype(dtype_act_);
+  auto dtype_weights = get_dtype(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                               num_local_experts, num_tokens);
+}
+
+std::vector<int64_t> trtllm_get_valid_moe_configs(
+    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
+    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
+    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = get_dtype(dtype_act_);
+  auto dtype_weights = get_dtype(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
+                                          num_tokens);
 }

 namespace trtllm_cubin_loader {
@@ -1133,6 +1211,8 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe);
   m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe);
   m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe);
+  m.def("trtllm_get_default_moe_configs", trtllm_get_default_moe_configs);
+  m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs);
 }

 } // namespace flashinfer
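Usage sketch for the two query ops registered above (not part of the commit): the example assumes the ops are exposed as torch.ops.flashinfer (the actual namespace depends on TORCH_EXTENSION_NAME) and uses made-up problem sizes. The integer dtype codes follow the get_dtype mapping above (0 = Bfloat16, 2 = E2m1, 13 = MxE4m3, ...), and passing config_index = -1 to trtllm_fp4_block_scale_moe keeps the default-config path.

import torch

E2M1 = 2  # btg::Dtype::E2m1 per the get_dtype mapping above
ops = torch.ops.flashinfer  # assumed namespace; depends on TORCH_EXTENSION_NAME

# Example problem sizes (illustrative only).
tile_tokens_dim, top_k, num_tokens = 8, 8, 128
hidden_size, intermediate_size, num_local_experts = 7168, 2048, 32

# Enumerate every valid kernel config index for this problem shape...
valid_configs = ops.trtllm_get_valid_moe_configs(
    tile_tokens_dim, E2M1, E2M1, False,
    top_k, hidden_size, intermediate_size, num_local_experts, num_tokens)

# ...or ask for the single default config index.
default_config = ops.trtllm_get_default_moe_configs(
    tile_tokens_dim, E2M1, E2M1, False,
    top_k, hidden_size, intermediate_size, num_local_experts, num_tokens)

# The chosen index is then forwarded as the new trailing config_index argument of
# trtllm_fp4_block_scale_moe; passing -1 reproduces the previous default behavior.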

flashinfer/autotuner.py

Lines changed: 72 additions & 26 deletions
@@ -7,7 +7,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional

 import torch

@@ -37,21 +37,49 @@ def get_config_path(is_module: bool):
     )


-@dataclass(slots=True, unsafe_hash=True)
+@dataclass(slots=True)
 class DynamicTensorSpec:
     """
     A specification for a dynamic tensor dimension.
     Args:
-        input_idx: The index of the input tensor.
-        dim_idx: The index of the dimension to tune.
+        input_idx: A list of the indices of the input tensors.
+        dim_idx: A list of the indices of the dimensions to tune.
+            The lengths of input_idx and dim_idx must be the same.
+            Every tensor referenced by input_idx must have the same size along its corresponding dim_idx dimension.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
         map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        tensor_initializers: A list of functions to initialize the tensors.
     """

-    input_idx: int
-    dim_idx: int
-    gen_tuning_buckets: Union[Tuple[int], Callable]
+    input_idx: Tuple[int, ...]
+    dim_idx: Tuple[int, ...]
+    gen_tuning_buckets: Union[Tuple[int, ...], Callable]
     map_to_tuning_buckets: Callable
+    tensor_initializers: List[Callable] = field(default_factory=lambda: None)
+
+    def __post_init__(self):
+        # Set default tensor_initializers if not provided
+        if self.tensor_initializers is None:
+            self.tensor_initializers = [
+                lambda shapes, dtype, device: torch.randn(
+                    shapes, device=device, dtype=dtype
+                )
+                for _ in range(len(self.input_idx))
+            ]
+
+    def __hash__(self) -> int:
+        # FIXME: currently not hashing tensor_initializers
+        return hash(
+            (
+                self.input_idx,
+                self.dim_idx,
+                # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id
+                self.gen_tuning_buckets
+                if isinstance(self.gen_tuning_buckets, tuple)
+                else id(self.gen_tuning_buckets),
+                id(self.map_to_tuning_buckets),
+            )
+        )


 @dataclass(slots=True, unsafe_hash=True)
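As a sketch of the widened spec (hypothetical indices and buckets, not part of the commit): a single DynamicTensorSpec can now tie several inputs to the same tuned dimension, and tensor_initializers controls how each of those tensors is filled during profiling.

import torch
from flashinfer.autotuner import DynamicTensorSpec

# Hypothetical example: inputs 0 and 2 share their first (token) dimension.
spec = DynamicTensorSpec(
    input_idx=(0, 2),
    dim_idx=(0, 0),
    gen_tuning_buckets=(32, 64, 128, 256),
    map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32,
    tensor_initializers=[
        # One initializer per entry in input_idx, each called as f(shapes, dtype, device).
        lambda shapes, dtype, device: torch.randn(shapes, device=device, dtype=dtype),
        lambda shapes, dtype, device: torch.randint(0, 16, shapes, device=device, dtype=dtype),
    ],
)

Omitting tensor_initializers falls back to torch.randn for every listed input, per __post_init__ above.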
@@ -85,8 +113,8 @@ class TuningConfig:
     >>> config = TuningConfig(
     ...     dynamic_tensor_specs=(
     ...         DynamicTensorSpec(
-    ...             input_idx=0,
-    ...             dim_idx=1,
+    ...             input_idx=[0],
+    ...             dim_idx=[1],
     ...             gen_tuning_buckets=(32, 64, 128),
     ...             map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32
     ...         ),
@@ -141,6 +169,7 @@ class OptimizationProfile:
     """Ranges of all tensors, all dimension"""

     shapes: List[List[Dim]]
+    tensor_initializers: List[Optional[Callable]]

     def get_hash_key(self):
         return self.get_opt_shapes()
@@ -426,7 +455,7 @@ def choose_one(
                 "All Given runners must be subclass of TunableRunner"
             )

-        profiles = self._optimization_profiles(tuning_config, inputs)
+        profiles = self._generate_optimization_profiles(tuning_config, inputs)
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)

@@ -532,7 +561,8 @@ def _profile_single_kernel(
         # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
         # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
         # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-        delay_kernel(self.stream_delay_micro_secs)
+        if self.stream_delay_micro_secs > 0:
+            delay_kernel(self.stream_delay_micro_secs)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)

@@ -551,7 +581,7 @@ def _profile_single_kernel(

         return avg_time

-    def _optimization_profiles(
+    def _generate_optimization_profiles(
         self, tuning_config: TuningConfig, inputs: List[torch.Tensor]
     ) -> List[OptimizationProfile]:
         """Generate optimization profiles for autotuning.
@@ -579,7 +609,8 @@ def _optimization_profiles(
                 else [StaticDim(0)]
                 )
                 for t in inputs
-            ]
+            ],
+            [None] * len(inputs),
         )

         generated_profiles: List[OptimizationProfile] = []
@@ -592,9 +623,18 @@ def _optimization_profiles(
             ), (
                 "The given dynamic dimension must provide a opt value generation function or a list of opt values"
             )
+            assert len(spec.input_idx) == len(spec.dim_idx), (
+                f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}"
+            )
+            assert len(spec.tensor_initializers) == len(spec.input_idx), (
+                f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}"
+            )
+            for i, idx in enumerate(spec.input_idx):
+                base_profile.tensor_initializers[idx] = spec.tensor_initializers[i]
+
             if inspect.isfunction(spec.gen_tuning_buckets):
                 opt_shapes = spec.gen_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx]._opt()
+                    base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt()
                 )
             else:
                 opt_shapes = spec.gen_tuning_buckets
@@ -617,9 +657,10 @@ def _optimization_profiles(
                 # TODO: fix me, how to set the min and max?
                 min_value = opt_value
                 max_value = opt_shapes_max[opt_value]
-                p.shapes[input_idx][dim_idx] = DynamicDim(
-                    min_value, opt_value, max_value
-                )
+                for i in range(len(input_idx)):
+                    p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim(
+                        min_value, opt_value, max_value
+                    )

             # Adjust the profile to satisfy the constraints
             for constraint_spec in tuning_config.constraint_specs:
@@ -653,14 +694,15 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx][spec.dim_idx] = spec.map_to_tuning_buckets(
-                base_profile[spec.input_idx][spec.dim_idx]
+            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
+                spec.map_to_tuning_buckets(
+                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
+                )
             )

         # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
         for constraint_spec in tuning_config.constraint_specs:
             base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1
-
         return tuple(tuple(shape) for shape in base_profile)

     @classmethod
679721
)
680722

681723
def _create_tensor_like(
682-
self, origin_tensor: torch.Tensor, dims: List[Dim]
724+
self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable
683725
) -> torch.Tensor:
684726
"""Create a new tensor matching the properties of the original tensor.
685727
@@ -704,18 +746,22 @@ def _create_tensor_like(
             # TODO: how to make sure the created Tensor has the min/max info
             assert isinstance(d, DynamicDim)
             shapes.append(d.opt)
-        # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
-        # One solution is to manituplate the tensor content to make it more like the real data
-        # during the tuning process. This can by controlled in the preparation phase by the runner.
-        return torch.zeros(shapes, dtype=dtype, device=device)
+        return initializer(shapes, dtype, device)

     def _prepare_input_tensors(
         self, profile: OptimizationProfile, inputs: List[torch.Tensor]
     ) -> List[torch.Tensor]:
+        default_initializer = lambda shapes, dtype, device: torch.rand(
+            shapes, device=device
+        ).to(dtype)
         tensors = []
         for i, p in enumerate(profile.shapes):
             if any(isinstance(d, DynamicDim) for d in p):
-                tensor = self._create_tensor_like(inputs[i], p)
+                tensor = self._create_tensor_like(
+                    inputs[i],
+                    p,
+                    profile.tensor_initializers[i] or default_initializer,
+                )
             else:
                 tensor = inputs[i]
             tensors.append(tensor)

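A note on the initializer contract introduced above (illustrative, not part of the commit): both default_initializer and every entry of tensor_initializers are called as initializer(shapes, dtype, device) and must return a tensor with exactly that shape, dtype, and device. A minimal sketch of a custom initializer, with an arbitrary scaling chosen for illustration:

import torch

def centered_initializer(shapes, dtype, device):
    # Mirror the default torch.rand(...).to(dtype) pattern, but center the values so that
    # data-dependent kernels (e.g. MoE routing) are profiled on less degenerate inputs
    # than the torch.zeros placeholder this commit removes.
    return (torch.rand(shapes, device=device) * 2.0 - 1.0).to(dtype)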