diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py new file mode 100644 index 000000000..c60cf0fc1 --- /dev/null +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -0,0 +1,227 @@ +import argparse +from typing import Optional, Literal +import torch +import numpy as np +from flashinfer import ( + fp4_quantize, + mxfp8_quantize, + next_positive_power_of_2, +) +from flashinfer.fused_moe import trtllm_fp4_block_scale_moe +from flashinfer.autotuner import autotune +from flashinfer.testing.utils import bench_gpu_time +from flashinfer.utils import device_support_pdl + + +def get_tile_tokens_dim(num_tokens, num_experts, top_k): + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def bench_trtllm_gen_fused_moe_autotuner( + tune_max_num_tokens: Optional[int], + quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + num_tokens: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + warmups: int, + iterations: int, +): + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( + torch.bfloat16 + ) + if quant_mode == "NvFP4xNvFP4": + hidden_states, hidden_states_scale = fp4_quantize( + hidden_states, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + ) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + num_tokens, -1 + ) + hidden_states_global_scale = 1.0 / 448.0 / 6.0 + elif quant_mode == "MxFP4xMxFP8": + hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( + num_tokens, -1 + ) + hidden_states_global_scale = 1.0 + else: # MxFP4xBf16 + hidden_states_scale = None + hidden_states_global_scale = 1.0 + + w13 = torch.randn( + num_experts, intermediate_size * 2, hidden_size, device=device + ).to(torch.bfloat16) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + if quant_mode == "NvFP4xNvFP4": + w13, w13_scale = fp4_quantize( + w13, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + ) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, -1 + ) + w2, w2_scale = fp4_quantize( + w2, + torch.tensor([448.0 * 6.0], device=device), + sf_vec_size=16, + sf_use_ue8m0=False, + ) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, -1 + ) + w13_global_scale = 1.0 / 448.0 / 6.0 + w2_global_scale = 1.0 
/ 448.0 / 6.0 + else: + w13, w13_scale = fp4_quantize( + w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + ) + w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( + num_experts, intermediate_size * 2, -1 + ) + w2, w2_scale = fp4_quantize( + w2, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True + ) + w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, -1 + ) + w13_global_scale = 1.0 + w2_global_scale = 1.0 + bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10 + bias2 = torch.randn(num_experts, hidden_size, device=device) * 10 + + tile_tokens_dim = get_tile_tokens_dim(num_tokens, num_experts, top_k) + output1_scale_scalar = torch.tensor( + [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + ) + output1_scale_gate_scalar = torch.tensor( + [hidden_states_global_scale * w13_global_scale] * num_experts, device=device + ) + output2_scale_scalar = torch.tensor( + [hidden_states_global_scale * w2_global_scale] * num_experts, device=device + ) + fn = lambda: trtllm_fp4_block_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + hidden_states_scale, + w13, + w13_scale, + bias13, + None, # gemm1_alpha + None, # gemm1_beta + None, # gemm1_clamp_limit + w2, + w2_scale, + bias2, + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + None, # routed_scaling_factor + tile_tokens_dim, + 1, + True, + enable_pdl, + None, + num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + ) + + def bench(do_autotune): + # warmup + with autotune(do_autotune): + for _ in range(warmups): + fn() + ms_list = bench_gpu_time( + fn, + repeat_iters=iterations, + ) + median_ms = np.median(ms_list) + return median_ms + + ms = bench(do_autotune=False) + ms_tuned = bench(do_autotune=True) + print( + f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" + ) + print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--quant-mode", + type=str, + default="MxFP4xMxFP8", + choices=["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + help="Quantization mode", + ) + parser.add_argument("--num-tokens", type=int, default=512, help="Number of tokens") + parser.add_argument( + "--tune-max-num-tokens", + type=int, + default=None, + help="Maximum number of tokens for tuning", + ) + parser.add_argument( + "--num-experts", type=int, default=128, help="Number of experts" + ) + parser.add_argument("--hidden-size", type=int, default=3072, help="Hidden size") + parser.add_argument( + "--intermediate-size", type=int, default=3072, help="Intermediate size" + ) + parser.add_argument("--top-k", type=int, default=4, help="Top-k experts per token") + parser.add_argument( + "--warmups", type=int, default=100, help="Number of warmup iterations" + ) + parser.add_argument( + "--iterations", type=int, default=100, help="Number of benchmark iterations" + ) + args = parser.parse_args() + bench_trtllm_gen_fused_moe_autotuner( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + ) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu
b/csrc/trtllm_fused_moe_kernel_launcher.cu index 0269083ed..524e11ccb 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -881,11 +881,10 @@ std::vector trtllm_fp4_block_scale_moe_launcher( TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states_scale must be fp8."); - TORCH_CHECK(hidden_states_scale.value().dim() == 1, "hidden_states_scale must be 1D."); - TORCH_CHECK(hidden_states_scale.value().sizes()[0] == - tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, - args.hidden_size / sf_vec_size), - "hidden_states_scale has incorrect size"); + TORCH_CHECK( + hidden_states_scale.value().numel() == tensorrt_llm::computeLinearLayoutSFSize( + args.num_tokens, args.hidden_size / sf_vec_size), + "hidden_states_scale has incorrect size"); } TORCH_CHECK(gemm1_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2, @@ -1059,7 +1058,8 @@ std::vector trtllm_fp4_block_scale_moe( std::optional n_group, std::optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, std::optional routed_scaling_factor, int64_t tile_tokens_dim, - int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output) { + int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output, + int64_t config_index) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; int const num_tokens = hidden_states.sizes()[0]; @@ -1112,8 +1112,10 @@ std::vector trtllm_fp4_block_scale_moe( mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim, tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true); - auto const moeConfigIndex = mRunner->getDefaultValidConfigIndex( - top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + if (config_index == -1) { + config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } return trtllm_fp4_block_scale_moe_launcher( routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale, @@ -1122,7 +1124,34 @@ std::vector trtllm_fp4_block_scale_moe( output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights, - moeConfigIndex, enable_pdl, output); + config_index, enable_pdl, output); +} + +int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_, + int64_t const dtype_weights_, bool const useDeepSeekFp8, + int64_t const top_k, int64_t const hidden_size, + int64_t const intermediate_size, + int64_t const num_local_experts, int64_t const num_tokens) { + auto dtype_act = static_cast(dtype_act_); + auto dtype_weights = static_cast(dtype_weights_); + tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( + dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, + tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true); + return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); +} + +std::vector trtllm_get_valid_moe_configs( + int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_, + bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size, + int64_t const intermediate_size, int64_t const num_local_experts, int64_t const 
num_tokens) { + auto dtype_act = static_cast(dtype_act_); + auto dtype_weights = static_cast(dtype_weights_); + tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( + dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, + tensorrt_llm::kernels::ActType::SwiGlu, /*useShuffledMatrixA*/ true); + return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, + num_tokens); } namespace trtllm_cubin_loader { @@ -1133,6 +1162,8 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe); m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe); m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe); + m.def("trtllm_get_default_moe_configs", trtllm_get_default_moe_configs); + m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs); } } // namespace flashinfer diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index af4dac115..2e153ad26 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Callable, Dict, List, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional import torch @@ -37,21 +37,49 @@ def get_config_path(is_module: bool): ) -@dataclass(slots=True, unsafe_hash=True) +@dataclass(slots=True) class DynamicTensorSpec: """ A specification for a dynamic tensor dimension. Args: - input_idx: The index of the input tensor. - dim_idx: The index of the dimension to tune. + input_idx: A list of the indices of the input tensors. + dim_idx: A list of the indices of the dimensions to tune. + The length of input_idx and dim_idx must be the same. + For every tensor mapped to the input_idx, their dimension mapped to the dim_idx must be the same. gen_tuning_buckets: A tuple of values to try or a function generating values. map_to_tuning_buckets: A function to map dimensions to valid values during inference. + tensor_initializers: A list of functions to initialize the tensors. """ - input_idx: int - dim_idx: int - gen_tuning_buckets: Union[Tuple[int], Callable] + input_idx: Tuple[int, ...] + dim_idx: Tuple[int, ...] + gen_tuning_buckets: Union[Tuple[int, ...], Callable] map_to_tuning_buckets: Callable + tensor_initializers: List[Callable] = field(default_factory=lambda: None) + + def __post_init__(self): + # Set default tensor_initializers if not provided + if self.tensor_initializers is None: + self.tensor_initializers = [ + lambda shapes, dtype, device: torch.randn( + shapes, device=device, dtype=dtype + ) + for _ in range(len(self.input_idx)) + ] + + def __hash__(self) -> int: + # FIXME: currently not hasing tensor_initializers + return hash( + ( + self.input_idx, + self.dim_idx, + # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id + self.gen_tuning_buckets + if isinstance(self.gen_tuning_buckets, tuple) + else id(self.gen_tuning_buckets), + id(self.map_to_tuning_buckets), + ) + ) @dataclass(slots=True, unsafe_hash=True) @@ -85,8 +113,8 @@ class TuningConfig: >>> config = TuningConfig( ... dynamic_tensor_specs=( ... DynamicTensorSpec( - ... input_idx=0, - ... dim_idx=1, + ... input_idx=[0], + ... dim_idx=[1], ... gen_tuning_buckets=(32, 64, 128), ... map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32 ... 
), @@ -141,6 +169,7 @@ class OptimizationProfile: """Ranges of all tensors, all dimension""" shapes: List[List[Dim]] + tensor_initializers: List[Optional[Callable]] def get_hash_key(self): return self.get_opt_shapes() @@ -190,11 +219,10 @@ def __call__(self, inputs, **kwargs): @abstractmethod def forward( self, - /, # tensors are position only inputs: List[torch.Tensor], - *, # all others are keyword args only tactic: int = -1, do_preparation: bool = False, + **kwargs, # all others are keyword args only ) -> Any: """Forward pass for tunable runners. @@ -426,7 +454,7 @@ def choose_one( "All Given runners must be subclass of TunableRunner" ) - profiles = self._optimization_profiles(tuning_config, inputs) + profiles = self._generate_optimization_profiles(tuning_config, inputs) # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) @@ -532,7 +560,8 @@ def _profile_single_kernel( # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops) # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity. - delay_kernel(self.stream_delay_micro_secs) + if self.stream_delay_micro_secs > 0: + delay_kernel(self.stream_delay_micro_secs) start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -551,7 +580,7 @@ def _profile_single_kernel( return avg_time - def _optimization_profiles( + def _generate_optimization_profiles( self, tuning_config: TuningConfig, inputs: List[torch.Tensor] ) -> List[OptimizationProfile]: """Generate optimization profiles for autotuning. @@ -579,7 +608,8 @@ def _optimization_profiles( else [StaticDim(0)] ) for t in inputs - ] + ], + [None] * len(inputs), ) generated_profiles: List[OptimizationProfile] = [] @@ -592,9 +622,18 @@ def _optimization_profiles( ), ( "The given dynamic dimension must provide a opt value generation function or a list of opt values" ) + assert len(spec.input_idx) == len(spec.dim_idx), ( + f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}" + ) + assert len(spec.tensor_initializers) == len(spec.input_idx), ( + f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}" + ) + for i, idx in enumerate(spec.input_idx): + base_profile.tensor_initializers[idx] = spec.tensor_initializers[i] + if inspect.isfunction(spec.gen_tuning_buckets): opt_shapes = spec.gen_tuning_buckets( - base_profile.shapes[spec.input_idx][spec.dim_idx]._opt() + base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt() ) else: opt_shapes = spec.gen_tuning_buckets @@ -617,9 +656,10 @@ def _optimization_profiles( # TODO: fix me, how to set the min and max? 
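# Illustrative note (not part of the patch): with the tuple-based spec, one bucket value is
# applied below to every listed tensor. For example, a hypothetical
#   DynamicTensorSpec(input_idx=(0, 4), dim_idx=(0, 0), gen_tuning_buckets=(8, 16, 32), ...)
# resizes dimension 0 of both tensor 0 and tensor 4 to the same bucket value, so tensors
# that share the token dimension stay consistent within each generated profile.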
min_value = opt_value max_value = opt_shapes_max[opt_value] - p.shapes[input_idx][dim_idx] = DynamicDim( - min_value, opt_value, max_value - ) + for i in range(len(input_idx)): + p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim( + min_value, opt_value, max_value + ) # Adjust the profile to satisfy the constraints for constraint_spec in tuning_config.constraint_specs: @@ -653,14 +693,15 @@ def _find_nearest_profile( base_profile = list(list(shape) for shape in shapes) for spec in tuning_config.dynamic_tensor_specs: - base_profile[spec.input_idx][spec.dim_idx] = spec.map_to_tuning_buckets( - base_profile[spec.input_idx][spec.dim_idx] + base_profile[spec.input_idx[0]][spec.dim_idx[0]] = ( + spec.map_to_tuning_buckets( + base_profile[spec.input_idx[0]][spec.dim_idx[0]] + ) ) # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile for constraint_spec in tuning_config.constraint_specs: base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1 - return tuple(tuple(shape) for shape in base_profile) @classmethod @@ -679,7 +720,7 @@ def _get_cache_key( ) def _create_tensor_like( - self, origin_tensor: torch.Tensor, dims: List[Dim] + self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable ) -> torch.Tensor: """Create a new tensor matching the properties of the original tensor. @@ -704,18 +745,22 @@ def _create_tensor_like( # TODO: how to make sure the created Tensor has the min/max info assert isinstance(d, DynamicDim) shapes.append(d.opt) - # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE - # One solution is to manituplate the tensor content to make it more like the real data - # during the tuning process. This can by controlled in the preparation phase by the runner. 
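# Illustrative sketch (assumption, not part of the patch): the initializer argument below is
# what replaces the previous torch.zeros fill. A runner that needs non-degenerate tuning
# inputs (important for MoE routing) can pass per-tensor callbacks through DynamicTensorSpec,
# e.g.
#   tensor_initializers=[lambda shapes, dtype, device: torch.rand(shapes, device=device).to(dtype)]
# each callback receives (shapes, dtype, device) and returns the profiling tensor.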
- return torch.zeros(shapes, dtype=dtype, device=device) + return initializer(shapes, dtype, device) def _prepare_input_tensors( self, profile: OptimizationProfile, inputs: List[torch.Tensor] ) -> List[torch.Tensor]: + default_initializer = lambda shapes, dtype, device: torch.rand( + shapes, device=device + ).to(dtype) tensors = [] for i, p in enumerate(profile.shapes): if any(isinstance(d, DynamicDim) for d in p): - tensor = self._create_tensor_like(inputs[i], p) + tensor = self._create_tensor_like( + inputs[i], + p, + profile.tensor_initializers[i] or default_initializer, + ) else: tensor = inputs[i] tensors.append(tensor) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 07df2b991..f21190c43 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -45,6 +45,7 @@ from .utils import ( get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, + next_positive_power_of_2, ) @@ -65,6 +66,78 @@ class RoutingMethodType(IntEnum): Unspecified = 5 +class DtypeTrtllmGen(IntEnum): + def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid): + value = ( + (block_format_bit << 24) + | (signed_bit << 20) + | (integer_bit << 16) + | (num_bits << 8) + | uid + ) + obj = int.__new__(cls, value) + obj._value_ = value + return obj + + # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h + Bfloat16 = (0, 1, 0, 16, 0) + Bool = (0, 0, 1, 1, 1) + E2m1 = (1, 1, 0, 4, 2) + E2m3 = (1, 1, 0, 6, 3) + E3m2 = (1, 1, 0, 6, 4) + E4m3 = (0, 1, 0, 8, 5) + E5m2 = (0, 1, 0, 8, 6) + Fp16 = (0, 1, 0, 16, 7) + Fp32 = (0, 1, 0, 32, 8) + Int8 = (0, 1, 1, 8, 9) + Int32 = (0, 1, 1, 32, 10) + Int64 = (0, 1, 1, 64, 11) + MxE2m1 = (1, 1, 0, 4, 12) + MxE4m3 = (1, 1, 0, 8, 13) + UE8m0 = (0, 0, 0, 8, 14) + UInt8 = (0, 0, 1, 8, 15) + UInt16 = (0, 0, 1, 16, 16) + UInt32 = (0, 0, 1, 32, 17) + UInt64 = (0, 0, 1, 64, 18) + UInt128 = (0, 0, 1, 128, 19) + Void = (0, 1, 0, 0, 20) + + +def trtllm_gen_dtype_has_scale(dtype: DtypeTrtllmGen) -> bool: + if dtype in [ + DtypeTrtllmGen.MxE4m3, + DtypeTrtllmGen.E2m1, + DtypeTrtllmGen.MxE2m1, + DtypeTrtllmGen.MxE4m3, + ]: + return True + else: + return False + + +def deduce_trtllm_gen_tensor_dtype( + x: torch.Tensor, scale: Optional[torch.Tensor] +) -> DtypeTrtllmGen: + hidden_size = x.shape[-1] + if x.dtype == torch.uint8: # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + hidden_size *= 2 + if x.dtype == torch.bfloat16: + dtype = DtypeTrtllmGen.Bfloat16 + elif x.dtype == torch.float8_e4m3fn: + dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3 + elif ( + x.dtype == torch.uint8 + ): # FIXME(siyuan): use torch.float4_e2m1x2 after torch 2.8 + assert scale is not None, "Scale tensor must be provided for float4x2 input" + if scale.shape[-1] == hidden_size // 16: + dtype = DtypeTrtllmGen.E2m1 + else: + dtype = DtypeTrtllmGen.MxE2m1 + else: + raise ValueError("Unsupported trtllm-gen input tensor.") + return dtype + + # See MatrixLayout from include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h class WeightLayout(IntEnum): # K-major layout (default). 
[Mn, K] @@ -318,8 +391,8 @@ class MoERunner(TunableRunner): tuning_config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( - 0, - 0, + (0,), + (0,), get_last_power_of_2_num_tokens_buckets(8192), lambda x: min(last_positive_power_of_2(x), 8192), ), @@ -392,9 +465,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - gemm_idx: int = 0, tactic: int = -1, do_preparation: bool = False, + **kwargs, ): ( x, @@ -418,7 +491,7 @@ def forward( self.cluster_rank, self.enable_alltoall, self.min_latency_mode, - gemm_idx, + kwargs["gemm_idx"], tactic, do_preparation, self.enable_pdl, @@ -430,8 +503,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): cls.tuning_config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( - 0, - 0, + (0,), + (0,), get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens), lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens), ), @@ -879,6 +952,241 @@ def get_trtllm_moe_sm100_module(): moe_op = module.build_and_load() setup_cubin_loader(str(module.get_library_path())) + class MoERunner(TunableRunner): + dynamic_tensor_initializers = [ + lambda shapes, dtype, device: torch.empty( + shapes, device=device, dtype=dtype + ), # output buffer, [num_tokens, hidden_size] + lambda shapes, dtype, device: torch.rand( + shapes, device=device, dtype=dtype + ), # routing_logits, [num_tokens, num_experts] + lambda shapes, dtype, device: torch.empty( + shapes, device=device, dtype=dtype + ), # topk_ids buffer. empty since routing_logits is used. [num_tokens, topk] + lambda shapes, dtype, device: torch.empty( + shapes, device=device, dtype=dtype + ), # expert_weights buffer. empty since routing_logits is used. [num_tokens, topk] + lambda shapes, dtype, device: torch.randn(shapes, device=device).to( + dtype + ), # hidden_states, [num_tokens, hidden_size] + lambda shapes, dtype, device: torch.ones(shapes, device=device).to( + dtype + ), # hidden_states_scale, [num_tokens, hidden_size // sf_vec_size] + ] + # their first dimension is num_tokens which will be tuned + tuning_config_with_hidden_states_scales = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0, 1, 2, 3, 4, 5), + (0, 0, 0, 0, 0, 0), + get_last_power_of_2_num_tokens_buckets(1024, 8), + lambda x: min(last_positive_power_of_2(x), 1024), + dynamic_tensor_initializers, + ), + ) + ) + tuning_config_no_hidden_states_scales = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0, 1, 2, 3, 4), + (0, 0, 0, 0, 0), + get_last_power_of_2_num_tokens_buckets(1024, 8), + lambda x: min(last_positive_power_of_2(x), 1024), + dynamic_tensor_initializers[:5], + ), + ), + ) + # cache the valid tactics to reduce the overhead of instantiating the runner + # TODO(siyuan): directly cache the runners + valid_tactics_dict = dict() + + def __init__( + self, + top_k: int, + num_experts: int, + dtype_act: DtypeTrtllmGen, + dtype_weights: DtypeTrtllmGen, + use_deepseek_fp8: bool, + hidden_size: int, + intermediate_size: int, + tile_tokens_dim: Optional[int] = None, + ): + self.num_experts = num_experts + self.top_k = top_k + self.dtype_act = dtype_act + self.dtype_weights = dtype_weights + self.use_deepseek_fp8 = use_deepseek_fp8 + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.tile_tokens_dim = tile_tokens_dim + + def get_tile_tokens_dim(self, num_tokens: int, top_k: int): + # Factor to account for the imbalance of the experts. 
+ # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + ( + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + *extra_inputs, + ) = inputs + num_tokens = routing_logits.shape[0] + tile_tokens_dim = ( + self.get_tile_tokens_dim(num_tokens, self.top_k) + if self.tile_tokens_dim is None + else self.tile_tokens_dim + ) + instance_key = ( + tile_tokens_dim, + self.dtype_act, + self.dtype_weights, + self.use_deepseek_fp8, + self.top_k, + self.hidden_size, + self.intermediate_size, + self.num_experts, + num_tokens, + ) + if instance_key not in MoERunner.valid_tactics_dict: + MoERunner.valid_tactics_dict[instance_key] = ( + moe_op.trtllm_get_valid_moe_configs(*instance_key) + ) + return MoERunner.valid_tactics_dict[instance_key] + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ): + ( + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + *extra_inputs, + ) = inputs + num_tokens = routing_logits.shape[0] + tile_tokens_dim = ( + self.get_tile_tokens_dim(num_tokens, self.top_k) + if self.tile_tokens_dim is None + else self.tile_tokens_dim + ) + + extra_input_idx = 0 + if trtllm_gen_dtype_has_scale(self.dtype_act): + hidden_states_scale = extra_inputs[extra_input_idx] + extra_input_idx += 1 + else: + hidden_states_scale = None + # sanity checks to ensure that dynamic tensors have the correct shapes + assert output.shape[0] == num_tokens, ( + "output's first dimension must be batch size." + ) + assert topk_ids.shape[0] == num_tokens, ( + "topk_ids's first dimension must be batch size." + ) + assert expert_weights.shape[0] == num_tokens, ( + "expert_weights's first dimension must be batch size." + ) + assert hidden_states.shape[0] == num_tokens, ( + "hidden_states's first dimension must be batch size." 
+ ) + assert hidden_states_scale is None or ( + hidden_states_scale.dim() == 2 + and hidden_states_scale.shape[0] == num_tokens + ), "hidden_states_scale's first dimension must be batch size" + + # TODO(siyuan): support fp8 + moe_op.trtllm_fp4_block_scale_moe( + routing_logits.to(torch.bfloat16), + topk_ids, + expert_weights, + kwargs["routing_bias"], + hidden_states, + hidden_states_scale, # hidden_states_scale + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm1_bias"], + kwargs["gemm1_alpha"], + kwargs["gemm1_beta"], + kwargs["gemm1_clamp_limit"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + kwargs["gemm2_bias"], + kwargs["output1_scale_scalar"], + kwargs["output1_scale_gate_scalar"], + kwargs["output2_scale_scalar"], + self.num_experts, + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + kwargs["num_local_experts"], + kwargs["routed_scaling_factor"], + tile_tokens_dim, + kwargs["routing_method_type"], + kwargs["do_finalize"], + kwargs["enable_pdl"], + output, + tactic, + ) + + @classmethod + @functools.lru_cache(maxsize=None) + def refine_tuning_config(cls, tune_max_num_tokens: int): + cls.tuning_config_with_hidden_states_scales = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0, 1, 2, 3, 4, 5), + (0, 0, 0, 0, 0, 0), + get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8), + lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens), + cls.dynamic_tensor_initializers, + ), + ) + ) + cls.tuning_config_no_hidden_states_scales = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (0, 1, 2, 3, 4), + (0, 0, 0, 0, 0), + get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 8), + lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens), + cls.dynamic_tensor_initializers[:5], + ), + ), + ) + @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", mutates_args=(""), @@ -1081,6 +1389,7 @@ def trtllm_fp4_block_scale_moe_op( do_finalize: bool, enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 1024, ) -> List[torch.Tensor]: if routing_logits is None: assert topk_ids is not None, ( @@ -1114,6 +1423,67 @@ def trtllm_fp4_block_scale_moe_op( device=hidden_states.device, ) + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + dtype_act = deduce_trtllm_gen_tensor_dtype(hidden_states, hidden_states_scale) + dtype_weights = deduce_trtllm_gen_tensor_dtype( + gemm1_weights, gemm1_weights_scale + ) + moe_runner = MoERunner( + top_k=top_k, + num_experts=num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=False, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + # NOTE(siyuan): do not fix the tile_tokens_dim, so that the tunable runner can decide it itself. + # however, when the user chooses a different heuristic for tile_tokens_dim, the autotuner will fail to find the correct cached tactics.
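# Worked example of the get_tile_tokens_dim heuristic above (illustrative): with
# num_tokens=512, top_k=4 and num_experts=128, the balanced load is (512 * 4) // 128 = 16
# tokens per expert; applying the 1.3 imbalance factor gives int(16 * 1.3) = 20, the next
# power of two is 32, which already lies in the supported [8, 64] range, so
# tile_tokens_dim = 32.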
+ # tile_tokens_dim=tile_tokens_dim, + ) + tunning_config = ( + MoERunner.tuning_config_no_hidden_states_scales + if hidden_states_scale is None + else MoERunner.tuning_config_with_hidden_states_scales + ) + inputs = [ + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + ] + if hidden_states_scale is not None: + inputs.append(hidden_states_scale) + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp4_block_scale_moe", + [moe_runner], + tunning_config, + inputs, + num_local_experts=num_experts, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm1_bias=gemm1_bias, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + gemm2_bias=gemm2_bias, + output1_scale_scalar=output1_scale_scalar, + output1_scale_gate_scalar=output1_scale_gate_scalar, + output2_scale_scalar=output2_scale_scalar, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + enable_pdl=enable_pdl, + do_finalize=do_finalize, + ) + # Call the C++ function for block scale MoE output = moe_op.trtllm_fp4_block_scale_moe( routing_logits, @@ -1147,6 +1517,7 @@ def trtllm_fp4_block_scale_moe_op( do_finalize, enable_pdl, output, + tactic, ) return output @@ -1182,8 +1553,9 @@ def _fake_trtllm_fp4_block_scale_moe( tile_tokens_dim: int, routing_method_type: int, do_finalize: bool, - enable_pdl: Optional[bool] = None, - output: Optional[torch.Tensor] = None, + enable_pdl: bool, + output: Optional[torch.Tensor], + tune_max_num_tokens: int, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1373,6 +1745,7 @@ def trtllm_fp4_block_scale_moe( do_finalize: bool = True, enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 1024, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. @@ -1389,10 +1762,20 @@ def trtllm_fp4_block_scale_moe( Tensor of FC1 weights. Dtype must be uint8 (packed fp4) gemm1_weights_scale (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // (32 if mxfp4 else 16)] Scale tensor of FC1 weights. Dtype must be float8. + gemm1_bias (Optional[torch.Tensor]): shape [num_experts, 2 * intermediate_size] + Tensor of FC1 biases. Dtype is float32. + gemm1_alpha (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu alpha. Dtype is float32. + gemm1_beta (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu beta. Dtype is float32. + gemm1_clamp_limit (Optional[torch.Tensor]): shape [num_experts] + Tensor of swiglu clamp limit. Dtype is float32. gemm2_weights (torch.Tensor): shape [num_experts, hidden_size, intermediate_size] Tensor of FC2 weights. Dtype must be uint8 (packed fp4) - gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size//128, intermediate_size//128] + gemm2_weights_scale (torch.Tensor): shape [num_experts, hidden_size, intermediate_size // (32 if mxfp4 else 16)] Scale tensor of FC2 weights. Dtype must be float8. + gemm2_bias (Optional[torch.Tensor]): shape [num_experts, hidden_size] + Tensor of FC2 biases. Dtype is float32. 
output1_scale_scalar (Optional[torch.Tensor]): shape [local_num_experts] Tensor of scaling factors for first layer activation output output1_scale_gate_scalar (Optional[torch.Tensor]): shape [local_num_experts] @@ -1415,6 +1798,7 @@ def trtllm_fp4_block_scale_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. @@ -1454,6 +1838,7 @@ def trtllm_fp4_block_scale_moe( do_finalize, enable_pdl, output, + tune_max_num_tokens, ) @@ -1487,6 +1872,7 @@ def trtllm_fp4_block_scale_routed_moe( do_finalize: bool = True, enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 1024, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. @@ -1531,6 +1917,7 @@ def trtllm_fp4_block_scale_routed_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 1024) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -1570,4 +1957,5 @@ def trtllm_fp4_block_scale_routed_moe( do_finalize, enable_pdl, output, + tune_max_num_tokens, ) diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py index 963a8af98..2bb196858 100644 --- a/flashinfer/fused_moe/utils.py +++ b/flashinfer/fused_moe/utils.py @@ -203,11 +203,13 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]: return tuple(num_token_buckets) -def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]: +def get_last_power_of_2_num_tokens_buckets( + max_num_tokens, min_num_tokens=1 +) -> Tuple[int]: max_num_tokens = last_positive_power_of_2(max_num_tokens) num_token_buckets = [] m = max_num_tokens - while m >= 1: + while m >= min_num_tokens: num_token_buckets.append(m) m //= 2 return tuple(num_token_buckets) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index b59a5af89..2996f3382 100755 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -100,9 +100,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - *, tactic: int = -1, do_preparation: bool = False, + **kwargs, ) -> torch.Tensor: cublas_handle = torch.cuda.current_blas_handle() a, b, scale_a, scale_b, out, workspace_buffer = inputs @@ -398,9 +398,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - *, tactic: int = -1, do_preparation: bool = False, + **kwargs, ) -> torch.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs module.fp8_gemm.default( @@ -447,8 +447,8 @@ def fp8_gemm_sm100( tuning_config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( - a_tensor_index, - -2, + (a_tensor_index,), + (-2,), get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ), @@ -489,9 +489,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - *, tactic: int = -1, do_preparation: bool = False, + **kwargs, ): a, b, a_descale, b_descale, alpha, out, workspace_buffer = inputs module.fp4_gemm.default( @@ -524,8 +524,8 @@ def pad_up(x, y): tuning_config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( - a_tensor_index, - 0, + 
(a_tensor_index,), + (0,), get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ), @@ -1421,9 +1421,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - *, tactic: int = -1, do_preparation: bool = False, + **kwargs, ) -> torch.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs _cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype) @@ -1946,9 +1946,9 @@ def get_valid_tactics( def forward( self, inputs: List[torch.Tensor], - *, tactic: int = -1, do_preparation: bool = False, + **kwargs, ): ( workspace_buffer, @@ -1998,8 +1998,8 @@ def pad_up(x, y): tuning_config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( - a_tensor_index, - 0, + (a_tensor_index,), + (0,), get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ), diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh index e3ea61eb7..7a6587e06 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh @@ -51,7 +51,8 @@ struct TopKRedType { static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) { auto valueBits = cub::Traits::TwiddleIn( reinterpret_cast::UnsignedBits&>(val)); - TypeCmp compactTmp = reinterpret_cast(valueBits); + TypeCmp compactTmp; + memcpy(&compactTmp, &valueBits, sizeof(valueBits)); compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); // Use 65535 minus idx to give higher priority to elements with smaller indices. return compactTmp; diff --git a/tests/test_trtllm_gen_fused_moe.py b/tests/test_trtllm_gen_fused_moe.py index 35526ccae..a8e07fd48 100644 --- a/tests/test_trtllm_gen_fused_moe.py +++ b/tests/test_trtllm_gen_fused_moe.py @@ -357,19 +357,11 @@ def quantize_inputs( hidden_states, is_swizzling ) hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape( - -1 - ) - print( - f"hidden_states.shape: {hidden_states_quant.shape}, dtype: {hidden_states_quant.dtype}" - ) - print( - f"hidden_states_scale.shape: {hidden_states_scale.shape}, dtype: {hidden_states_scale.dtype}" + *hidden_states.shape[:-1], -1 ) return { "hidden_states": hidden_states_quant, - "hidden_states_scale": hidden_states_scale.view( - torch.float8_e4m3fn - ).reshape(-1), + "hidden_states_scale": hidden_states_scale, } elif self.quant_mode == QuantMode.FP4_NVFP4_NVFP4: """Quantize hidden states to NvFP4 format using pre-computed global scale.""" @@ -380,12 +372,13 @@ def quantize_inputs( ) = quant_fp4( hidden_states, hidden_states_scale_global, False, is_swizzling ) + hidden_states_scale_fp4_bytes = hidden_states_scale_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape(*hidden_states.shape[:-1], -1) return { "hidden_states": hidden_states_fp4_bytes, - "hidden_states_scale": hidden_states_scale_fp4_bytes.view( - torch.float8_e4m3fn - ).reshape(-1), + "hidden_states_scale": hidden_states_scale_fp4_bytes, } else: # bf16 return { @@ -1742,9 +1735,9 @@ def cache_permute_indices(): @pytest.mark.parametrize( "moe_impl", [ - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4 x NvFP4"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4 x MxFP8"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4 x Bf16"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + 
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), ],
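For reference, a minimal sketch of how the new autotuning path is exercised end to end, mirroring benchmarks/bench_trtllm_gen_fused_moe_autotuner.py. It assumes a CUDA device; run_moe is a hypothetical stand-in for a fully argued trtllm_fp4_block_scale_moe call (with tune_max_num_tokens set), so treat it as illustrative rather than a drop-in snippet.

import torch
from flashinfer.autotuner import autotune

def run_moe() -> torch.Tensor:
    # Stand-in for the benchmark's `fn` lambda, i.e. a trtllm_fp4_block_scale_moe(...)
    # call whose quantized activations, weights and scales are built as in the benchmark.
    return torch.zeros(1, device="cuda")

# Profiling pass: while autotune(True) is active, each op that reaches AutoTuner.choose_one
# times its valid tactics (trtllm_get_valid_moe_configs for this MoE) over the bucketed
# token counts and caches the fastest one.
with autotune(True):
    for _ in range(10):
        run_moe()

# Steady state: the cached tactic is reused. Without prior tuning, config_index stays -1
# and the C++ side falls back to getDefaultValidConfigIndex.
out = run_moe()

The same flow can be driven from the command line, e.g. python benchmarks/bench_trtllm_gen_fused_moe_autotuner.py --quant-mode MxFP4xMxFP8 --num-tokens 512, which prints the median latency with and without autotuning.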