
Commit f590bf2

WIP
1 parent c6edf1a commit f590bf2

3 files changed: +125 -111 lines changed


csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 16 additions & 1 deletion
@@ -1079,7 +1079,7 @@ std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
   auto mDtypeAct = btg::Dtype::Bfloat16;
   if (hidden_states.scalar_type() == torch_ext::FLOAT4_E2M1X2) {
     TORCH_CHECK(hidden_states_scale.has_value() &&
-                    hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
+                    hidden_states_scale.value().scalar_type() == at::ScalarType::Byte,
                 "hidden_states_scale must be provided for fp4 activation.");
     if (hidden_states_scale_vec_size == 16) {
       mDtypeAct = btg::Dtype::E2m1;
@@ -1171,6 +1171,20 @@ inline btg::Dtype get_dtype(int64_t const dtype) {
   return btg::Dtype::E2m1;
 }
 
+int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
+                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
+                                       int64_t const top_k, int64_t const hidden_size,
+                                       int64_t const intermediate_size,
+                                       int64_t const num_local_experts, int64_t const num_tokens) {
+  auto dtype_act = get_dtype(dtype_act_);
+  auto dtype_weights = get_dtype(dtype_weights_);
+  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
+      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
+      tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true);
+  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                               num_local_experts, num_tokens);
+}
+
 std::vector<int64_t> trtllm_get_valid_moe_configs(
     int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
     bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
@@ -1192,6 +1206,7 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe);
   m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe);
   m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe);
+  m.def("trtllm_get_default_moe_configs", trtllm_get_default_moe_configs);
   m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs);
 }
 
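The new trtllm_get_default_moe_configs op mirrors trtllm_get_valid_moe_configs but returns a single default config index rather than the full list of valid indices. A minimal Python-side call sketch follows; the module handle moe_ops, the dtype enum codes, and all shape values are illustrative assumptions, not part of this commit.

# Hypothetical sketch: `moe_ops` stands for whatever object exposes the ops
# registered in TORCH_LIBRARY_FRAGMENT above; the dtype code and shape values
# below are placeholders, not values taken from this commit.
DTYPE_E2M1 = 0  # placeholder enum code decoded by get_dtype()

config_index = moe_ops.trtllm_get_default_moe_configs(
    8,           # tile_tokens_dim
    DTYPE_E2M1,  # dtype_act_
    DTYPE_E2M1,  # dtype_weights_
    False,       # useDeepSeekFp8
    4,           # top_k
    4096,        # hidden_size
    14336,       # intermediate_size
    32,          # num_local_experts
    256,         # num_tokens
)
# The returned int64 indexes into the runner's kernel-config list, analogous to
# the indices returned by trtllm_get_valid_moe_configs.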

flashinfer/autotuner.py

Lines changed: 37 additions & 31 deletions
@@ -7,9 +7,10 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional
 
 import torch
+from tqdm import tqdm
 
 # from tensorrt_llm.bindings.internal.runtime import delay_kernel
 # from tensorrt_llm.logger import logger
@@ -42,16 +43,20 @@ class DynamicTensorSpec:
     """
     A specification for a dynamic tensor dimension.
     Args:
-        input_idx: The index of the input tensor.
-        dim_idx: The index of the dimension to tune.
+        input_idx: A list of the indices of the input tensors.
+        dim_idx: A list of the indices of the dimensions to tune.
+            The length of input_idx and dim_idx must be the same.
+            For every tensor mapped to the input_idx, their dimension mapped to the dim_idx must be the same.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
         map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        tensor_initializers: A list of functions to initialize the tensors.
     """
 
-    input_idx: int
-    dim_idx: int
+    input_idx: Tuple[int]
+    dim_idx: Tuple[int]
     gen_tuning_buckets: Union[Tuple[int], Callable]
     map_to_tuning_buckets: Callable
+    # tensor_initializers: Tuple[Callable] = field(default_factory=lambda: [lambda shapes, dtype, device: torch.randn(shapes, device=device, dtype=dtype)])
 
 
 @dataclass(slots=True, unsafe_hash=True)
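Because input_idx and dim_idx are now parallel sequences, a single spec can tie one tuned dimension across several inputs. A minimal sketch with assumed tensor indices and bucket values (not taken from this commit):

# Hypothetical spec: inputs 0 and 2 are assumed to share their token dimension
# (dim 0), so the same bucketed value is applied to both during profiling.
spec = DynamicTensorSpec(
    input_idx=(0, 2),
    dim_idx=(0, 0),
    gen_tuning_buckets=(32, 64, 128, 256),
    map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32,  # round up to the next multiple of 32
)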
@@ -85,8 +90,8 @@ class TuningConfig:
         >>> config = TuningConfig(
         ...     dynamic_tensor_specs=(
         ...         DynamicTensorSpec(
-        ...             input_idx=0,
-        ...             dim_idx=1,
+        ...             input_idx=[0],
+        ...             dim_idx=[1],
         ...             gen_tuning_buckets=(32, 64, 128),
         ...             map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32
         ...         ),
@@ -426,7 +431,7 @@ def choose_one(
                 "All Given runners must be subclass of TunableRunner"
             )
 
-            profiles = self._optimization_profiles(tuning_config, inputs)
+            profiles = self._generate_optimization_profiles(tuning_config, inputs)
             # Record the total configs to try
             self.stats.tuned_op_total_configs[custom_op] = len(profiles)
 
@@ -532,7 +537,8 @@ def _profile_single_kernel(
         # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
         # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
         # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-        delay_kernel(self.stream_delay_micro_secs)
+        if self.stream_delay_micro_secs > 0:
+            delay_kernel(self.stream_delay_micro_secs)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
 
@@ -551,7 +557,7 @@ def _profile_single_kernel(
 
         return avg_time
 
-    def _optimization_profiles(
+    def _generate_optimization_profiles(
         self, tuning_config: TuningConfig, inputs: List[torch.Tensor]
     ) -> List[OptimizationProfile]:
         """Generate optimization profiles for autotuning.
@@ -592,9 +598,12 @@ def _optimization_profiles(
             ), (
                 "The given dynamic dimension must provide a opt value generation function or a list of opt values"
             )
+            assert len(spec.input_idx) == len(spec.dim_idx), (
+                "The number of input indices and dimension indices must be the same"
+            )
             if inspect.isfunction(spec.gen_tuning_buckets):
                 opt_shapes = spec.gen_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx]._opt()
+                    base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt()
                 )
             else:
                 opt_shapes = spec.gen_tuning_buckets
@@ -617,9 +626,10 @@
                 # TODO: fix me, how to set the min and max?
                 min_value = opt_value
                 max_value = opt_shapes_max[opt_value]
-                p.shapes[input_idx][dim_idx] = DynamicDim(
-                    min_value, opt_value, max_value
-                )
+                for i in range(len(input_idx)):
+                    p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim(
+                        min_value, opt_value, max_value
+                    )
 
             # Adjust the profile to satisfy the constraints
             for constraint_spec in tuning_config.constraint_specs:
@@ -653,14 +663,15 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx][spec.dim_idx] = spec.map_to_tuning_buckets(
-                base_profile[spec.input_idx][spec.dim_idx]
+            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
+                spec.map_to_tuning_buckets(
+                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
+                )
             )
 
         # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
         for constraint_spec in tuning_config.constraint_specs:
             base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1
-
         return tuple(tuple(shape) for shape in base_profile)
 
     @classmethod
@@ -679,7 +690,7 @@ def _get_cache_key(
         )
 
     def _create_tensor_like(
-        self, origin_tensor: torch.Tensor, dims: List[Dim]
+        self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable
     ) -> torch.Tensor:
         """Create a new tensor matching the properties of the original tensor.
 
704715
# TODO: how to make sure the created Tensor has the min/max info
705716
assert isinstance(d, DynamicDim)
706717
shapes.append(d.opt)
707-
# TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
708-
# One solution is to manituplate the tensor content to make it more like the real data
709-
# during the tuning process. This can by controlled in the preparation phase by the runner.
710-
# return torch.zeros(shapes, dtype=dtype, device=device)
711-
if dtype == torch.int8:
712-
return torch.randint(0, 127, shapes, dtype=dtype, device=device)
713-
elif dtype == torch.uint8:
714-
return torch.randint(0, 255, shapes, dtype=dtype, device=device)
715-
elif dtype == torch.int32:
716-
return torch.randint(0, 1000000, shapes, dtype=dtype, device=device)
717-
else:
718-
return torch.randn(shapes, dtype=dtype, device=device)
718+
return initializer(shapes, dtype, device)
719719

720720
def _prepare_input_tensors(
721721
self, profile: OptimizationProfile, inputs: List[torch.Tensor]
722722
) -> List[torch.Tensor]:
723723
tensors = []
724724
for i, p in enumerate(profile.shapes):
725725
if any(isinstance(d, DynamicDim) for d in p):
726-
tensor = self._create_tensor_like(inputs[i], p)
726+
tensor = self._create_tensor_like(
727+
inputs[i],
728+
p,
729+
lambda shapes, dtype, device: torch.rand(shapes, device=device).to(
730+
dtype
731+
),
732+
)
727733
else:
728734
tensor = inputs[i]
729735
tensors.append(tensor)
