diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 5741611644..25b938b11b 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -32,11 +32,459 @@ namespace flashinfer { namespace btg = batchedGemm::trtllm::gen; +using batchedGemm::gemm::MatrixLayout; using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +/* + +Abstraction layers: + +1. TORCH_LIBRARY_FRAGMENT bindings +These are currently the same signature as the public python APIs. +We strive to make the python interface relatively stable +and the naming of parameters meaningful to the users. + +2. FusedMoeLauncher +This performs checks and preparations for the execution, +organized in several stages, see FusedMoeLauncher::run(). + +3. MoE::Runner +Orchestrate and dispatch all the kernels executions to fulfill the requested operation. +This includes PermuteGemm1, Gemm2, activation (if not fused), and finalize. + +4. TrtllmGenBatchedGemmRunner +This provides tactic selection if not determined yet at the public API (or auto-tuning) + +5. BatchedGemm Runner +The low-level gemm kernel executor which is updated together with the kernels. + +6. BatchedGemmInterface +Driver calls take place to carry out the gemm operations. +*/ + +class FusedMoeLauncher { + protected: + Tensor routing_logits{}; + Tensor routing_bias{}; + Tensor hidden_states{}; + Tensor gemm1_weights{}; + Tensor output1_scales_scalar{}; + Tensor output1_scales_gate_scalar{}; + Tensor gemm2_weights{}; + Tensor output2_scales_scalar{}; + + int64_t tile_tokens_dim{}; + int64_t routing_method_type{}; + bool use_shuffled_weight{}; + MatrixLayout weight_layout{MatrixLayout::MajorK}; + + std::tuple device_version; + std::unique_ptr args; + tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; + + btg::Dtype mDtypeAct{btg::Dtype::Bfloat16}; + btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; + GatedActType gated_act_type{GatedActType::SwiGlu}; + + // Initialize common data necessary for later. + // May throw exception from TVM_FFI_ICHECK. 
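+  // Call sequence (see run() below): a derived launcher's init() invokes init_common() once,
+  // then run() drives check_routing() -> prepare_routing() -> routing kernels -> check_moe() ->
+  // prepare_moe() -> MoE kernels.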
+ void init_common(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, + Tensor const& gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type); + + // Routing logits [num_tokens, num_experts] + void check_routing_logits_shape() const { + TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0]) + << "routing_logits and hidden_states must have the same number of tokens."; + TVM_FFI_ICHECK_EQ(routing_logits->shape[1], args->num_experts) + << "routing_logits dim1 must match num_experts."; + } + + // Routing bias [num_experts] + void check_routing_bias_shape() const { + if (routing_bias.defined()) { + TVM_FFI_ICHECK_EQ(routing_bias->ndim, 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias->shape[0], args->num_experts) + << "routing_bias has incorrect shape."; + } + } + + // Hidden states [num_tokens, hidden_size] + void check_hidden_states_shape() const { + TVM_FFI_ICHECK_EQ(hidden_states->ndim, 2) << "hidden_states must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states->shape[1], args->intermediate_size) + << "hidden_states has incorrect shape."; + } + + // GEMM1 or GEMM2 weights [num_experts, M, K] or [num_experts, K/block_k, M, block_k] + void check_weights_shape(std::string which_weights) const { + Tensor weights{}; + if (which_weights == "gemm1") { + weights = gemm1_weights; + } else if (which_weights == "gemm2") { + weights = gemm2_weights; + } else { + TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights; + } + + int64_t Mn = 0, K = 0; + if (weight_layout == MatrixLayout::MajorK) { + // MajorK [num_experts, M, K] + Mn = weights->shape[1]; + K = weights->shape[2]; + } else if (weight_layout == MatrixLayout::BlockMajorK) { + // BlockMajorK [num_experts, K/block_k, M, block_k] + Mn = weights->shape[2]; + int64_t block_k = weights->shape[3]; + K = weights->shape[1] * block_k; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "Unsupported weight_layout: " << (int)weight_layout; + } + TVM_FFI_ICHECK_EQ(weights->shape[0], args->num_experts) + << which_weights << " weights expert dimension must match num_experts"; + if (which_weights == "gemm1") { + TVM_FFI_ICHECK_EQ(Mn % 2, 0) << which_weights << " weights Mn dimension must be even."; + TVM_FFI_ICHECK_EQ(args->intermediate_size, Mn / 2) + << "intermediate_size has incorrect shape."; + TVM_FFI_ICHECK_EQ(K, hidden_states->shape[1]) + << which_weights << " weights K dimension must be equal to hidden_size."; + } else if (which_weights == "gemm2") { + TVM_FFI_ICHECK_EQ(K, args->intermediate_size) + << which_weights << " weights K dimension must be equal to intermediate_size."; + } + } + + void check_routing_common() const { + TVM_FFI_ICHECK(args->top_k > 0 && args->top_k <= args->num_experts) + << "top_k must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_num_experts > 0 && args->local_num_experts <= args->num_experts) + << "local_num_experts must be between 1 and num_experts"; + TVM_FFI_ICHECK(args->local_expert_offset >= 0 && + args->local_expert_offset + args->local_num_experts <= args->num_experts) + << "expert offset and count must be within valid range"; + + check_routing_logits_shape(); + + if (routing_bias.defined()) { + check_routing_bias_shape(); + } + } + + // Routing phase workspace tensors 
(allocated in prepare_routing() or prepare_routing_common()) + Tensor num_tokens_per_expert; + Tensor total_num_padded_tokens; + Tensor expanded_idx_to_permuted_idx; + Tensor permuted_idx_to_token_idx; + Tensor expert_weights; + Tensor expert_indexes; + Tensor expert_count_histogram; + Tensor cta_idx_xy_to_batch_idx; + Tensor cta_idx_xy_to_mn_limit; + Tensor num_non_exiting_ctas; + + void prepare_routing_common() { + // Allocate routing phase workspace tensors + int32_t max_num_padded_tokens = + tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + // Common routing workspace tensors allocation + num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, routing_logits->device); + + total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device); + + expanded_idx_to_permuted_idx = + alloc_tensor({args->num_tokens * args->top_k}, dl_int32, routing_logits->device); + + permuted_idx_to_token_idx = + alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device); + + expert_indexes = + alloc_tensor({args->num_tokens, args->top_k}, dl_int32, routing_logits->device); + + // expert_weights allocation should be done by derived class since data type could vary + + int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); + expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, + dl_int32, // 256 is the max number of threads per block + // and max number of experts + routing_logits->device); + + int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( + args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + + cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); + + cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); + + num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits->device); + + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); + workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.ProjUpTileN = tile_tokens_dim; + workspace.routing_expert_indexes = static_cast(expert_indexes->data); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); + workspace.expanded_idx_to_permuted_idx = static_cast(expanded_idx_to_permuted_idx->data); + workspace.permuted_idx_to_token_idx = static_cast(permuted_idx_to_token_idx->data); + // workspace.expert_weights will be set by derived class after expert_weights allocation + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx->data); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit->data); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas->data); + } + + void check_moe_common() const { + // Hidden states [num_tokens, hidden_size] + TVM_FFI_ICHECK_EQ(hidden_states->ndim, 2) << "hidden_states must be 2D."; + } + + // MoE computation phase workspace tensors (allocated in prepare_moe() or prepare_moe_common()) + Tensor gemm1_output; + Tensor activation_output; + Tensor gemm2_output; + Tensor workspace_fc1; + Tensor workspace_fc2; + Tensor output; + int64_t moe_tactic{-1}; + std::unique_ptr moe_runner; + + void prepare_moe_common(int64_t& moe_tactic) { + using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, + args->mUseDeepSeekFp8, 
(int32_t)tile_tokens_dim, + static_cast(this->gated_act_type), + this->use_shuffled_weight, this->weight_layout); + + if (moe_tactic == -1) { + moe_tactic = moe_runner->getDefaultValidConfigIndex( + args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, + args->num_tokens); + } + this->moe_tactic = moe_tactic; + + auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); + workspace_fc1 = alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states->device); + workspace_fc2 = alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states->device); + workspace.bmm1_workspace = workspace_fc1->data; + workspace.bmm2_workspace = workspace_fc2->data; + } + + public: + virtual void check_routing() const = 0; + virtual void prepare_routing() = 0; + virtual void check_moe() const = 0; + virtual void prepare_moe(int64_t& moe_tactic) = 0; + + // Main entry point for all the executions. + // Do initializations prior to calling this as the initializations are different for bf16, fp8 and + // fp4. The executions are non-blocking by default. + Array run(int64_t moe_tactic, bool enable_pdl = true) { + check_routing(); + prepare_routing(); + + // Execute routing + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + cudaStream_t routing_stream = get_stream(routing_logits->device); + routing_runner.run(routing_logits->data, args->routing_bias, args->num_tokens, + args->num_experts, args->top_k, args->n_group, args->topk_group, + args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(expert_indexes->data), + static_cast(expert_count_histogram->data), + static_cast(total_num_padded_tokens->data), + static_cast(expanded_idx_to_permuted_idx->data), + nullptr /*permuted_idx_to_expanded_idx->data*/, + static_cast(permuted_idx_to_token_idx->data), expert_weights->data, + static_cast(num_tokens_per_expert->data), + static_cast(cta_idx_xy_to_batch_idx->data), + static_cast(cta_idx_xy_to_mn_limit->data), + static_cast(num_non_exiting_ctas->data), args->mDtypeElt, false, true, + static_cast(routing_method_type), routing_stream); + + check_moe(); + // if moe_tactic is -1, it will be set to the default valid config index + prepare_moe(moe_tactic); + + // Execute MoE + cudaStream_t moe_stream = get_stream(hidden_states->device); + moe_runner->run(*args, workspace, hidden_states->device.device_id, moe_stream, moe_tactic, + enable_pdl); + + if (args->do_finalize) { + return {output}; + } + return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx}; + } +}; + +void FusedMoeLauncher::init_common( + Tensor const& routing_logits, Optional const& routing_bias, Tensor const& hidden_states, + Tensor const& gemm1_weights, Tensor const& gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, int64_t gated_act_type) { + // Check devicearchitecture: Blackwell (SM 10.x) required + auto device = hidden_states->device.device_id; + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + TVM_FFI_ICHECK_EQ(major, 10) << "BF16 MoE requires 10.x architecture. 
Current device has SM " + << major << minor; + this->device_version = std::make_tuple(major, minor); + + this->routing_logits = routing_logits; + if (routing_bias.has_value()) { + this->routing_bias = routing_bias.value(); + } + this->hidden_states = hidden_states; + this->gemm1_weights = gemm1_weights; + this->gemm2_weights = gemm2_weights; + + args->routing_logits = routing_logits->data; + args->routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr; + args->hidden_states = hidden_states->data; + args->gemm1_weights = gemm1_weights->data; + args->gemm2_weights = gemm2_weights->data; + + this->args = std::move(args); + this->tile_tokens_dim = tile_tokens_dim; + this->routing_method_type = routing_method_type; + this->use_shuffled_weight = use_shuffled_weight; + TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) + << "the value of weight_layout is not recognized"; + this->weight_layout = static_cast(weight_layout); + TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) + << "the value of gated_act_type is not recognized"; + this->gated_act_type = static_cast(gated_act_type); +} + +class Bf16MoeLauncher : public FusedMoeLauncher { + public: + Bf16MoeLauncher() = default; + + void init(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, Tensor const& gemm2_weights, + std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout) { + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now + + // Do base class init and perform common checks + FusedMoeLauncher::init_common( + routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, std::move(args), + tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, gated_act_type); + } + + void check_routing() const override { + FusedMoeLauncher::check_routing_common(); + + // TODO n_group, topk_group validation? 
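+    // (Possible checks for the TODO above, stated here only as an assumption about grouped
+    // routing semantics: num_experts % n_group == 0 and topk_group <= n_group when n_group > 0.)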
+ } + + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); + + args->mDtypeElt = btg::Dtype::Bfloat16; + args->mDtypeExpW = btg::Dtype::Bfloat16; + args->mUseDeepSeekFp8 = false; + + auto const routing_bias_dtype = dl_bfloat16; + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, routing_bias_dtype, routing_logits->device); + + workspace.expert_weights = expert_weights->data; + } + + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); + + TVM_FFI_ICHECK(weight_layout == MatrixLayout::BlockMajorK) + << "BF16 Moe: weight_layout must be BlockMajorK"; + check_weights_shape("gemm1"); + check_weights_shape("gemm2"); + + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "the second dimension of weights must be a multiple of 128."; + } + + void prepare_moe(int64_t& moe_tactic) override { + // in the next line moe_tactic is passed by reference so modification will be propagated back + // here + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + int32_t max_num_padded_tokens = workspace.total_max_padded_tokens; + gemm1_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states->device); + activation_output = alloc_tensor({max_num_padded_tokens, args->intermediate_size}, dl_bfloat16, + hidden_states->device); + gemm2_output = alloc_tensor({max_num_padded_tokens, args->hidden_size}, dl_bfloat16, + hidden_states->device); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output->data; + workspace.gemm1_output_scale = nullptr; // BF16 doesn't use scale tensors + workspace.activation_output = activation_output->data; + workspace.activation_output_scale = nullptr; // BF16 doesn't use scale tensors + workspace.gemm2_output = gemm2_output->data; + workspace.gemm2_output_scale = nullptr; + + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states->device); + args->output = output->data; + args->output_scale = nullptr; + } +}; + +Tensor trtllm_bf16_moe(Tensor const& routing_logits, Optional const& routing_bias, + Tensor const& hidden_states, Tensor const& gemm1_weights, + Tensor const& gemm2_weights, int64_t num_experts, int64_t top_k, + int64_t n_group, int64_t topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, + int64_t tile_tokens_dim, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, int64_t moe_tactic, + bool enable_pdl) { + // Just some basic type validation first and leave more checks to the launcher + TVM_FFI_ICHECK(routing_logits->dtype == dl_float32 || routing_logits->dtype == dl_bfloat16) + << "BF16 MoE: routing_logits must be bfloat16 or float."; + if (routing_bias.has_value()) { + TVM_FFI_ICHECK_EQ(routing_bias.value()->dtype, dl_bfloat16) + << "BF16 MoE: routing_bias must be bfloat16."; + } + TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_bfloat16) + << "BF16 MoE: hidden_states must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights->dtype, dl_bfloat16) + << "BF16 MoE: gemm1_weights must be bfloat16."; + TVM_FFI_ICHECK_EQ(gemm2_weights->dtype, dl_bfloat16) + << "BF16 MoE: gemm2_weights must be bfloat16."; + + // Save params to MoE arguments + auto args = std::make_unique(); + args->num_tokens = hidden_states->shape[0]; + args->num_experts = num_experts; + args->hidden_size = hidden_states->shape[1]; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group; + args->topk_group = topk_group; + 
args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + + Bf16MoeLauncher launcher; + launcher.init(routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights, + std::move(args), tile_tokens_dim, routing_method_type, use_shuffled_weight, + weight_layout); + auto data = launcher.run(moe_tactic, enable_pdl)[0]; + return data; +} + Tensor trtllm_fp8_per_tensor_scale_moe_launcher( Tensor routing_logits, Optional routing_bias, Tensor hidden_states, Tensor gemm1_weights, Tensor output1_scales_scalar, Tensor output1_scales_gate_scalar, @@ -290,9 +738,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe( n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, enable_pdl); - } else { - TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } + + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } void trtllm_fp8_block_scale_moe_launcher( @@ -1128,6 +1576,7 @@ namespace trtllm_cubin_loader { #include } +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_bf16_moe, trtllm_bf16_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_per_tensor_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_moe); diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 931431fa2f..fa69b492af 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -138,7 +138,6 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize; - routingData.mPtrScores = routingLogits; // diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4d344d02bb..ac981cf985 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -76,7 +76,7 @@ def get_available_cubin_files( class ArtifactPath: TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen" TRTLLM_GEN_BMM: str = ( - "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802" + "a72d85b019dc125b9f711300cb989430f762f5a6/batched_gemm-145d1b1-9e1d49a" ) TRTLLM_GEN_GEMM: str = ( "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e" @@ -91,7 +91,7 @@ class MetaInfoHash: "2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d" ) TRTLLM_GEN_BMM: str = ( - "c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34" + "8c5f97d582df0e4fd9f69ddeb3b72cc3a37915c6c20b4d0905fec69702310b63" ) DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index 2759105691..a27267f4bf 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -29,6 +29,7 @@ trtllm_fp4_block_scale_routed_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, + trtllm_bf16_moe, ) __all__ = [ @@ -44,4 +45,5 @@ "trtllm_fp4_block_scale_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_per_tensor_scale_moe", + "trtllm_bf16_moe", ] diff --git 
a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index e6cd1973e9..d6727e0f8e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1268,6 +1268,81 @@ def refine_tuning_config(cls, tune_max_num_tokens: int): ), ) + @register_custom_op( + "flashinfer::trtllm_bf16_moe", + mutates_args=(""), + ) + def trtllm_bf16_moe_op( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + tile_tokens_dim: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + moe_tactic: int, + enable_pdl: Optional[bool] = None, + ) -> torch.Tensor: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + # Call the C++ function for block scale MoE + output = moe_op.trtllm_bf16_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + tile_tokens_dim, + routing_method_type, + use_shuffled_weight, + weight_layout, + moe_tactic, + enable_pdl, + ) + return output + + @register_fake_op("flashinfer::trtllm_bf16_moe") + def _fake_trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + tile_tokens_dim: int, + routing_method_type: int, + use_shuffled_weight: bool, + weight_layout: int, + moe_tactic: int, + enable_pdl: Optional[bool] = None, + ): + seq_len = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + + return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( "flashinfer::trtllm_fp8_per_tensor_scale_moe", mutates_args=(""), @@ -1658,12 +1733,57 @@ def _fake_trtllm_fp4_block_scale_moe( return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] return SimpleNamespace( + trtllm_bf16_moe=trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op, ) +def trtllm_bf16_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + *, + tile_tokens_dim: int = 8, + routing_method_type: int = 0, + use_shuffled_weight: bool = True, + weight_layout: int = WeightLayout.BlockMajorK, + moe_tactic: int = -1, + enable_pdl: bool = True, +) -> torch.Tensor: + """BF16 block scale MoE operation.""" + return get_trtllm_moe_sm100_module().trtllm_bf16_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm2_weights, + num_experts, + top_k, + n_group or 0, # may receive None from test configs, convert to 0 + topk_group or 0, + intermediate_size, + local_expert_offset, + local_num_experts, + tile_tokens_dim, + routing_method_type, + use_shuffled_weight, + weight_layout, + moe_tactic, + 
enable_pdl, + ) + + def trtllm_fp8_per_tensor_scale_moe( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 629dfc4d27..0edb7661e8 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -506,12 +506,27 @@ class BatchedGemmInterface { throw std::invalid_argument("Invalid combination of options"); } - int32_t const numCtasTile = + if (batchM) { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX); + } else { + numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY); + } + + int32_t numCtasTile = batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM); + if (batchM) { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY); + } else { + numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX); + } int32_t const numCtasInner = options.mNumSlicesForSplitK; return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner); } + // Creates GemmOptions from kernel and data. + BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, + BatchedGemmData const& data) const; + // Returns the number of CTAs of the current kernel. int32_t getNumCtas(BatchedGemmOptions const& options, std::optional maxNumCtasInBatchDim = std::nullopt) const { @@ -522,10 +537,6 @@ class BatchedGemmInterface { // Returns true if the configuration of the cubin can be executed for the given params. bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const; - // Creates GemmOptions from kernel and data. - BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, - BatchedGemmData const& data) const; - private: // Aligns the pointer to the alignment template @@ -781,6 +792,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa if (result != CUDA_SUCCESS) { return -1; } + // If a module cache has not been given, unload the module to avoid leaking if (!moduleCache.has_value()) { cuModuleUnload(cuModule); diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h index c29fb24b0a..f2a7d5dafd 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h @@ -76,39 +76,43 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { // FIXME We create explicit constructor with all options to WAR stubgen issue in TRT-LLM. 
BatchedGemmOptions( gemm::AllReduceAlgo allReduceAlgo, gemm::BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, tg::Dtype dtypeB, - tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, bool enablesEarlyExit, - bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, int epilogueLdtmDps, - int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, - bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, - bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, - gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, - int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, - int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, - int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, - bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, - tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, - int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, int tileM, int tileN, - gemm::TileScheduler tileScheduler, bool transposeMmaOutput, bool useCustomMmaSchedule, - bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, bool usePerTokenSfA, - bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, bool useTwoTmaLoadWarps, - bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, gemmGatedAct::ActType actType, - bool clampBeforeAct, std::vector batchedM, std::vector batchedN, - BatchMode batchMode, int numBatches, bool isStaticBatch, int numTokens, RouteImpl routeImpl, - bool gridWaitForPrimaryRouting, bool fusedAct, int numRegsPerThreadNonEpilogueWarp, - int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt) + int clusterDimY, int clusterDimZ, gemm::CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, + int epilogueTileN, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, + bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, + bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, + gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, + int mmaM, int mmaN, bool mockAllReduce, int n, int numRegsCastAWarps, + int numRegsCopySfLdsSttm, int numRegsPerThreadEpilogueWarp, + int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, + int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, + int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, + std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, + tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, gemm::SplitK splitK, int tileK, + int tileM, int tileN, gemm::TileScheduler tileScheduler, bool transposeMmaOutput, + bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useHoistTryWaitForCustomMmaSchedule, + bool usePerTokenSfA, bool usePerTokenSfB, bool useShuffledMatrixA, bool useTmaStore, + bool useTwoTmaLoadWarps, bool useTwoMmaWarps, bool useUnrollLoop2xForMma, int worldSize, + gemmGatedAct::ActType actType, bool 
clampBeforeAct, std::vector batchedM, + std::vector batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch, + int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting, bool fusedAct, + bool useTmaOobOpt) : gemmGatedAct::GemmGatedActOptions( gemm::GemmOptions( - allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, - dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, - enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits, - epilogueTileM, epilogueTileN, gridTriggerSecondaryA, gridTriggerSecondaryB, - gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, - hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, - mmaKind, mmaM, mmaN, mockAllReduce, n, numSlicesForSplitK, numSlicesForSliceK, - numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, - numStagesWorkId, outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, - sfLayoutC, sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, + allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, + ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, + enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, + epilogueLdtmBits, epilogueTileM, epilogueTileN, gridTriggerSecondaryA, + gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, + gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits, + layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numRegsCastAWarps, + numRegsCopySfLdsSttm, numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, + numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma, + numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, + outputDebugTensors, patchF2fp, sfBlockSizeA, sfLayoutA, sfLayoutB, sfLayoutC, + sfReshapeFactor, sliceK, splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule, usePerTokenSfA, usePerTokenSfB, useShuffledMatrixA, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, @@ -121,9 +125,6 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting), mIsStaticBatch(isStaticBatch), mNumBatches(numBatches), - mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), - mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), - mNumRegsCastAWarps(numRegsCastAWarps), mNumTokens(numTokens), mRouteImpl(routeImpl), mUseTmaOobOpt(useTmaOobOpt) {} @@ -143,12 +144,6 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions { bool mIsStaticBatch{true}; // Number of Gemm batches. int mNumBatches; - // Number of registers per thread for non-epilogue warps - int mNumRegsPerThreadNonEpilogueWarp{0}; - // Number of registers per thread for epilogue warps - int mNumRegsPerThreadEpilogueWarp{0}; - // Number of registers for the cast A warps. - int mNumRegsCastAWarps{0}; // Total number of tokens. int mNumTokens{32}; // Whether load the input tokens and do routing. @@ -269,16 +264,8 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw } } - if (doesRouteImplUseTma(options.mRouteImpl)) { - TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N."); - - if (tg::mmaKindIsBlockFmt(options.mMmaKind)) { - auto dtypeRoute = batchM ? 
options.mDtypeA : options.mDtypeB; - TLLM_CHECK_ERROR(options.mTileK % tg::dtypeNumEltsPerSf(dtypeRoute) == 0, - "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); - TLLM_CHECK_ERROR(options.mTileK % (tg::dtypeNumEltsPerSf(dtypeRoute) * 16) == 0, - "tileK needs to be a multiple of 16 * tg::dtypeNumEltsPerSf(dtypeA)."); - } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(!batchM, "2CTA Gemm currently only supports batch N."); } if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl)) { @@ -323,6 +310,7 @@ struct BatchedGemmConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif BatchedGemmOptions mOptions; @@ -345,11 +333,6 @@ inline std::string dumpOptions(BatchedGemmOptions const& options) { << std::endl; ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl; ss << "mFusedAct=" << options.mFusedAct << "," << std::endl; - ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," - << std::endl; - ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," - << std::endl; - ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl; return ss.str(); } diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h index 6f2b1c270d..e9d5a23a65 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h @@ -97,6 +97,23 @@ enum class TileScheduler { //////////////////////////////////////////////////////////////////////////////////////////////////// +enum class CtaSwizzleType : uint32_t { + // Rasterize CTAs along the M dimension. + RasterizeAlongM = 0, + // Rasterize CTAs along the N dimension. + RasterizeAlongN, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2. + ZigZagAlongM2, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2. + ZigZagAlongN2, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4. + ZigZagAlongM4, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4. + ZigZagAlongN4, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Helper functions to check the SplitK type. 
#define SPLIT_K_FUNCTION(Mode) \ diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h index 1086cd4fd5..e796bcc23c 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h @@ -191,6 +191,7 @@ struct GemmGatedActConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif GemmGatedActOptions mOptions{}; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h index f9c7044700..8710da2a85 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h @@ -92,16 +92,18 @@ struct GemmOptions { GemmOptions() = default; GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, - tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, - bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, - bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, - MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, + int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, + int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, + MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, + bool mockAllReduce, int n, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, + int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, @@ -117,6 +119,7 @@ struct GemmOptions { mClusterDimX{clusterDimX}, mClusterDimY{clusterDimY}, mClusterDimZ{clusterDimZ}, + mCtaSwizzleType{ctaSwizzleType}, mDtypeAcc{dtypeAcc}, mDtypeA{dtypeA}, mDtypeB{dtypeB}, @@ -148,6 +151,10 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumRegsCastAWarps(numRegsCastAWarps), + mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), + 
mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), mNumSlicesForSplitK{numSlicesForSplitK}, mNumSlicesForSliceK{numSlicesForSliceK}, mNumStages{numStages}, @@ -193,6 +200,8 @@ struct GemmOptions { int mClusterDimY{1}; // Cluster size in Z dim. int mClusterDimZ{1}; + // The type of CTA swizzle. + CtaSwizzleType mCtaSwizzleType{CtaSwizzleType::RasterizeAlongM}; // Data type of the accumulators. tg::Dtype mDtypeAcc{tg::Dtype::Fp32}; // Data type of the A matrix. @@ -263,6 +272,14 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of registers for the cast A warps. + int mNumRegsCastAWarps{0}; + // Number of registers for the LDS+STTM warps. + int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread for epilogue warps + int mNumRegsPerThreadEpilogueWarp{0}; + // Number of registers per thread for non-epilogue warps + int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -369,6 +386,7 @@ struct GemmConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif GemmOptions mOptions{}; @@ -409,6 +427,8 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; + ss << "mCtaSwizzleType=" << "gemm::CtaSwizzleType(" + << static_cast(options.mCtaSwizzleType) << ")" << "," << std::endl; ss << "mDtypeAcc=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" << "," << std::endl; ss << "mDtypeA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" << "," @@ -449,6 +469,12 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; ss << "mN=" << options.mN << "," << std::endl; + ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; + ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," + << std::endl; + ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," + << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -673,18 +699,27 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { - // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. - int newTileM = 128 * divUp(options.mTileM, 128); - TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, - ") for MmaKind=", gemm::toString(options.mMmaKind), - ". 
Setting MmaM to 128 and TileM to ", newTileM); - if (updateOptions) { - options.mMmaM = 128; - options.mTileM = newTileM; + if (options.mClusterDimX == 1) { + // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. + int newTileM = 128 * divUp(options.mTileM, 128); + TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, + ") for MmaKind=", gemm::toString(options.mMmaKind), + ". Setting MmaM to 128 and TileM to ", newTileM); + if (updateOptions) { + options.mMmaM = 128; + options.mTileM = newTileM; + } else { + return false; + } } else { - return false; + TLLM_CHECK_ERROR(options.mMmaM == 256 && options.mTileM == 128, + "2CTA UTCxMMA only supports mmaM = 256 and tileM = 128."); } } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(options.mLayoutB != MatrixLayout::BlockMajorK, + "layoutB == MatrixLayout::BlockMajorK is not supported for now"); + } if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); @@ -869,14 +904,26 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (!options.mSliceK) { - TLLM_CHECK_ERROR(options.mMmaM <= options.mEpilogueTileM, + TLLM_CHECK_ERROR(options.mMmaM / options.mClusterDimX <= options.mEpilogueTileM, "EpilogueTileM must be larger or equal than mmaM."); + } else { + // FIXME: this is not necessary limitation. Simply fixing num repeats in TmemSliceKA should be + // enough. + TLLM_CHECK_ERROR((options.mTileN & (options.mTileN - 1)) == 0, + "For Slice-K TileN is required to be a power of 2"); + } + + if (options.mClusterDimX == 2) { + TLLM_CHECK_ERROR(options.mMmaM == 256, "Only mmaM = 256 is supported for 2CTA UTCMMA."); + TLLM_CHECK_ERROR(options.mMmaN % 16 == 0, "mmaN needs to be multiple of 16 for 2CTA UTCMMA."); } + TLLM_CHECK_ERROR( options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); - TLLM_CHECK_ERROR(options.mClusterDimX == 1 && options.mClusterDimY == 1, - "GEMM does not support cluster in X and Y dimensions."); + TLLM_CHECK_ERROR( + (options.mClusterDimX == 1 || options.mClusterDimX == 2) && options.mClusterDimY == 1, + "GEMM does not support cluster in X and Y dimensions."); TLLM_CHECK_ERROR(options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k."); TLLM_CHECK_ERROR(options.mTileM <= 128, "GEMM does not support TileM > 128."); @@ -1003,6 +1050,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Non-DeepSeekFp8 requires persistent scheduler when using numStagesMma >1"); } } + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8"); + } if (options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. 
Found dtypeA=", @@ -1259,7 +1309,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, - options.mUsePerTokenSfA, options.mUsePerTokenSfB, options.mBiasType); + options.mUsePerTokenSfA, options.mUsePerTokenSfB, + /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); } return true; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h index eba3f54737..9cb615c750 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h @@ -31,7 +31,6 @@ namespace batchedGemm { namespace batchedGemm { - //////////////////////////////////////////////////////////////////////////////////////////////////// // TODO: Find a better header to put this in, that we can include from here. @@ -137,12 +136,11 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in std::vector shape = {static_cast(hiddenSize), static_cast(numTokens)}; if (useTmaOobOpt /* also implies input/output activation */) { - // If TMA OOB optimization is used, we use 3D logical shape (M, tileM, K) or (N, tileN, K). - // The outer dimension is extended to make room for the possible counterbalance positive - // offset from the middle "bound" dimension. The counterbalance should be no more than - // ctaTileNumTokens. + // If TMA OOB optimization is used: + // Shape [hidden, tokens] Stride [1, hidden] becomes + // Shape [hidden, tileN, TmaDimMax, TmaDimMax] Stride [1, hidden, XLargeN - hidden, hidden] shape = {static_cast(hiddenSize), static_cast(ctaTileNumTokens), - static_cast(numTokens + ctaTileNumTokens)}; + static_cast(tg::TmaDimMax), static_cast(tg::TmaDimMax)}; } else if (isWeights) { // If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K). shape = {static_cast(hiddenSize), static_cast(numTokens), @@ -153,7 +151,8 @@ static auto makeTmaShapeStrideAbc(GemmOptions const& options, int mM, int mN, in // Swap the first two dimension as mentioned before. std::vector stride = {1, static_cast(hiddenSize)}; if (useTmaOobOpt) { - stride = {1, static_cast(hiddenSize), static_cast(hiddenSize)}; + stride = {1, static_cast(hiddenSize), static_cast(tg::XLargeN - hiddenSize), + static_cast(hiddenSize)}; } else if (isWeights) { stride = {1, static_cast(hiddenSize), static_cast(hiddenSize) * static_cast(numTokens)}; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h index 16b4af3149..e11374739f 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h @@ -29,54 +29,6 @@ struct KernelParams { // Maximum number of CTAs in the batch-token dimension. static constexpr int MaxNumCtas = 2048; - // NOTE: TMA out-of-bounds optimization for MoE padded tokens: - // - // Originally the padded tokens is a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1, - // hiddenDim] and box size [tileM, tileN] at pointer p. 
We waste bandwidth bytes since we only - // want to load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while - // box size needs to be fixed at compile time. - // - // To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with - // stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. For the original 2D - // tensor, - // - // Offset Coords [ : , ctaIdxY * tileN ], - // Box Sizes [ : , tileN ], - // Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN], - // - // while we only want load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), 1 <= batchEnd - // <= tileN - // - // For the reshaped 3D tensor, - // - // Offset Coords [ : , tileN - batchEnd , - // ctaIdxY * tileN + batchEnd ], - // Box Sizes [ : , tileN , - // 1 ], - // Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd), - // ctaIdxY * tileN + batchEnd : ctaIdx * tileN + batchEnd + 1], - // - // while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are - // essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimension - // has the same stride, we end up loading the following (adding the left and right end of the 2nd - // and 3rd dimension ranges): - // - // Effective 2D Coords Range - // [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd], - // - // This is exactly the same as the original range except for the offset tileN, thus we also need - // to offset the pointer in the opposite direction: - // - // Ptr (p) -> Ptr (p - tileN * hiddenDim) - // - // Due to the restrictions of TMA unit, the above operations requires the TMA descriptor and the - // underlying buffer be constructed differently: - // - Requires valid buffer at (p - tileN * hidden) - needs prepending `tileN` tokens. - // - TMA outermost dimension must be extended by `tileN` or loads will OOB in the rightmost side. - // The latter is because when batchEnd == tileN, the offset coords in the 3rd dimension becomes - // ctaIdxY * tileN + tileN. When ctaIdxY = ctaGridDimY - 1, it becomes ((ctaGridDimY - 1) * tileN - // + tileN = ctaGridDimY * tileN which is equal to the 3rd dimension size and will be filtered - // out. That's why we need to extend the tensor size by tileN. // // TMA descriptor for A. // Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h index 640b3a69f0..4d79f83c23 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h @@ -163,7 +163,7 @@ class KernelTraits { int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, - bool usePerTokenSfB, BiasType biasType) + bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) : mMmaKind{mmaKind} { // // SMEM @@ -213,8 +213,8 @@ class KernelTraits { // LoadB { // Number of bytes in load B shared memory. - auto const numSmemBytesLoadB = - numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + auto const numSmemBytesLoadB = numStages * (useTwoCtas ? 
tileN / 2 : tileN) * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h index a1412444ae..c7b18af138 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h @@ -156,7 +156,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -251,7 +251,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl; + ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; diff --git a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h index 393949a516..53155c8ffb 100644 --- a/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h @@ -23,6 +23,21 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA OOB optimization constants. +// +// CUDA Programming Guide states that "globalDim must be non-zero and less than or equal to 2^32". +// In practice, the kernel acts funny with TMA shape of 2^32 so we use 2^31. +constexpr unsigned long TmaDimMax = 1UL << 31; +// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow and effectively +// becomes 0. As sizeof(dtype) can be as small as 0.5B, we choose LargeN = 2^30 and XLargeN = 2^35 +// so overflow can happen. +constexpr unsigned long LargeN = 1UL << 30; +// Used in TMA stride. Should be less than 2^40. 
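+// With LargeN = 2^30 and XLargeN = 2^35 the product is 2^65 elements, so even at 0.5 B per
+// element the byte stride reaches 2^64 and wraps to 0 as described above, while 2^35 stays
+// under the 2^40 TMA stride limit.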
diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py
index 880c739259..74994c8e77 100644
--- a/tests/moe/test_trtllm_gen_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_fused_moe.py
@@ -38,6 +38,7 @@ from flashinfer.fused_moe import (
     WeightLayout,
     convert_to_block_layout,
+    trtllm_bf16_moe,
     trtllm_fp4_block_scale_moe,
     trtllm_fp8_block_scale_moe,
     trtllm_fp8_per_tensor_scale_moe,
@@ -219,6 +220,7 @@ class QuantMode(IntEnum):
     FP4_MXFP4_Bf16 = 3
     FP8_BLOCK_SCALE = 4
     FP8_PER_TENSOR = 5
+    BF16 = 6


 # ====================================================================================
@@ -983,6 +985,143 @@ def get_tolerances(self):
         return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}


+# ====================================================================================
+# BF16 Implementation
+# ====================================================================================
+
+
+class BF16Moe(Moe):
+    """BF16 MoE implementation."""
+
+    def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample):
+        """No scaling for weights."""
+        return {
+            "hidden_states_scale_global": None,
+            "gemm1_weights": gemm1_weights.to(torch.bfloat16),
+            "gemm1_scales": None,
+            "gemm1_scales_global": None,
+            "gemm2_weights": gemm2_weights.to(torch.bfloat16),
+            "gemm2_scales": None,
+            "gemm2_scales_global": None,
+        }
+
+    def quantize_inputs(self, hidden_states, *unused_args):
+        """No scaling for hidden states."""
+        return {
+            "hidden_states": hidden_states.to(torch.bfloat16),
+            "hidden_states_scale": None,
+        }
+
+    def prepare_static_weights_for_kernel(
+        self,
+        args_dequant,
+        args,
+        gemm1_weights_orig,
+        gemm2_weights_orig,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+        weight_processing,
+    ):
+        """Prepare quantized weights for kernel (done offline with weights)."""
+
+        # Use shuffled weights with BlockMajorK layout for better performance
+        use_shuffled_weight = weight_processing["use_shuffled_weight"]
+        weight_layout = weight_processing["layout"]
+
+        if use_shuffled_weight:
+            # FIXME: this depends on the kernel internals
+            epilogue_tile_m = 128
+
+            # Reorder rows of W1 for fused gated activation
+            gemm1_weights_bf16_shuffled = []
+            gemm2_weights_bf16_shuffled = []
+            for i in range(num_experts):
+                tmp_weights1 = reorder_rows_for_gated_act_gemm(
+                    args.gemm1_weights[i].clone().view(torch.uint8)
+                )
+                tmp_weights1 = shuffle_matrix_a(tmp_weights1, epilogue_tile_m)
+                tmp_weights2 = shuffle_matrix_a(
+                    args.gemm2_weights[i].clone().view(torch.uint8), epilogue_tile_m
+                )
+
+                if weight_layout == WeightLayout.BlockMajorK:
+                    block_k = 128
+                    tmp_weights1 = convert_to_block_layout(
+                        tmp_weights1.view(torch.uint8), block_k
+                    )
+                    tmp_weights2 = convert_to_block_layout(
+                        tmp_weights2.view(torch.uint8), block_k
+                    )
+
+                gemm1_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16))
+                gemm2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16))
+
+            # Stack weights for all experts
+            gemm1_weights_bf16_shuffled = (
+                torch.stack(gemm1_weights_bf16_shuffled)
+                .view(torch.bfloat16)
+                .contiguous()
+            )
+            gemm2_weights_bf16_shuffled = (
+                torch.stack(gemm2_weights_bf16_shuffled)
+                .view(torch.bfloat16)
+                .contiguous()
+            )
+
+        return {
+            "gemm1_weights": gemm1_weights_bf16_shuffled,
+            "gemm2_weights": gemm2_weights_bf16_shuffled,
+            "use_shuffled_weight": use_shuffled_weight,
+            "weight_layout": weight_layout,
+        }
+
+    def call_moe(
+        self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs
+    ):
+        """Call MoE with runtime input quantization + kernel execution (done at runtime)."""
+        expert_logits = kwargs["expert_logits"]
+        routing_bias = kwargs["routing_bias"]
+        num_experts = kwargs["num_experts"]
+        top_k = kwargs["top_k"]
+        n_groups = kwargs["n_groups"]
+        top_k_groups = kwargs["top_k_groups"]
+        intermediate_size = kwargs["intermediate_size"]
+        routing_method_type = kwargs["routing_method_type"]
+        tile_tokens_dim = kwargs["tile_tokens_dim"]
+
+        output = trtllm_bf16_moe(
+            expert_logits,  # float
+            routing_bias,
+            hidden_states_orig,
+            static_data["gemm1_weights"],
+            static_data["gemm2_weights"],
+            num_experts,
+            top_k,
+            n_groups,
+            top_k_groups,
+            intermediate_size,
+            0,
+            num_experts,
+            # the rest are enforced by the api to be passed in the keyword form
+            # as opposed to the positional form
+            use_shuffled_weight=static_data["use_shuffled_weight"],
+            weight_layout=static_data["weight_layout"],
+            tile_tokens_dim=tile_tokens_dim,
+            routing_method_type=routing_method_type,
+        )
+
+        return output.to(torch.float)
+
+    def compute_reference(self, args):
+        """BF16 reference implementation."""
+        return run_moe_reference_bf16(args)
+
+    def get_tolerances(self):
+        """Get BF16 accuracy tolerances."""
+        return {"atol": 0.1, "rtol": 0.85, "percent": 0.925}
+
+
 # ====================================================================================
 # Quantizer Factory
 # ====================================================================================
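The convert_to_block_layout calls above are what produce the BlockMajorK weight layout consumed by the kernel. As a rough mental model only (an assumption based on the layout name and the block_k = 128 used above; convert_to_block_layout in flashinfer.fused_moe remains the source of truth), it can be pictured as reshaping a row-major [M, K] matrix into [K // block_k, M, block_k]:

    import torch

    def block_major_k(w_2d: torch.Tensor, block_k: int) -> torch.Tensor:
        # [M, K] row-major -> [K // block_k, M, block_k]
        m, k = w_2d.shape
        assert k % block_k == 0
        return w_2d.reshape(m, k // block_k, block_k).permute(1, 0, 2).contiguous()

    w = torch.randint(0, 256, (4, 256), dtype=torch.uint8)  # toy [M, K] tile
    print(block_major_k(w, 128).shape)  # torch.Size([2, 4, 128])

Viewing the bf16 weights as uint8 before the conversion, as the test does, keeps both bytes of each bf16 element inside the same block.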
"use_shuffled_weight": use_shuffled_weight, + "weight_layout": weight_layout, + } + + def call_moe( + self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs + ): + """Call MoE with runtime input quantization + kernel execution (done at runtime).""" + expert_logits = kwargs["expert_logits"] + routing_bias = kwargs["routing_bias"] + num_experts = kwargs["num_experts"] + top_k = kwargs["top_k"] + n_groups = kwargs["n_groups"] + top_k_groups = kwargs["top_k_groups"] + intermediate_size = kwargs["intermediate_size"] + routing_method_type = kwargs["routing_method_type"] + tile_tokens_dim = kwargs["tile_tokens_dim"] + + output = trtllm_bf16_moe( + expert_logits, # float + routing_bias, + hidden_states_orig, + static_data["gemm1_weights"], + static_data["gemm2_weights"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + # the rest are enforced by the api to be passed in the keyword form + # as opposed to the positional form + use_shuffled_weight=static_data["use_shuffled_weight"], + weight_layout=static_data["weight_layout"], + tile_tokens_dim=tile_tokens_dim, + routing_method_type=routing_method_type, + ) + + return output.to(torch.float) + + def compute_reference(self, args): + """BF16 reference implementation.""" + return run_moe_reference_bf16(args) + + def get_tolerances(self): + """Get BF16 accuracy tolerances.""" + return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} + + # ==================================================================================== # Quantizer Factory # ==================================================================================== @@ -1276,8 +1415,6 @@ def check_accuracy(a, b, atol, rtol, percent): count = torch.sum(left > right) mismatch_percent = count / a.numel() if mismatch_percent > 1 - percent: - print(a) - print(b) raise Exception( f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " f"(threshold: {1 - percent:.4f})" @@ -1584,6 +1721,9 @@ def run_moe_dequant(args, quant_mode: QuantMode): .to(torch.float) ) args.c_global_sf = 1.0 + elif quant_mode == QuantMode.BF16: + activation_output = activation_output.to(torch.bfloat16).to(torch.float) + args.c_global_sf = 1.0 else: # mxfp4Bf16 activation_output = activation_output.to(torch.bfloat16).to(torch.float) args.c_global_sf = 1.0 @@ -1789,6 +1929,37 @@ def run_moe_reference_per_tensor_scale_fp8(args): return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant +def run_moe_reference_bf16(args): + """BF16 reference implementation.""" + + # no scaling for hidden states and weights + hidden_states_dequant = args.hidden_states.to(torch.float) + gemm1_weights_dequant = {} + for i in range(args.num_experts): + gemm1_weights_dequant[i] = args.gemm1_weights[i].to(torch.float) + gemm2_weights_dequant = {} + for i in range(args.num_experts): + gemm2_weights_dequant[i] = args.gemm2_weights[i].to(torch.float) + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + GatedActType.SwiGlu.value, # gated_act_type + ) + + return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant + + def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): """Unified actual computation that delegates to implementation-specific methods.""" # 1. 
@@ -1848,6 +2019,7 @@ def cache_permute_indices():
         pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
         pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
         pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
+        pytest.param(BF16Moe(), id="All_BF16"),
     ],
 )
 @pytest.mark.parametrize(
@@ -1914,7 +2086,12 @@ def cache_permute_indices():
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.Renormalize,
-            "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
+            "compatible_moe_impls": [
+                FP4Moe,
+                FP8PerTensorMoe,
+                FP8BlockScaleMoe,
+                BF16Moe,
+            ],
         },
         id="Renorm",
         marks=pytest.mark.skip(
@@ -1931,7 +2108,7 @@ def cache_permute_indices():
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.RenormalizeNaive,
-            "compatible_moe_impls": [FP4Moe],
+            "compatible_moe_impls": [FP4Moe, BF16Moe],
         },
         id="RenormNaive",
     ),
@@ -1988,7 +2165,7 @@ def cache_permute_indices():
         {
             "use_shuffled_weight": True,
             "layout": WeightLayout.BlockMajorK,
-            "compatible_moe_impls": [FP8BlockScaleMoe],
+            "compatible_moe_impls": [FP8BlockScaleMoe, BF16Moe],
         },
         id="Shuffled_BlockMajorK",
     ),
@@ -2024,31 +2201,34 @@ def test_moe_quantization_classes(
     if compute_capability[0] in [11, 12]:
         pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
     # Skip incompatible combinations
+    # NOTE Don't worry about the short-lived variables between if & skip - formatting simplifications
     if gated_act_type == GatedActType.GeGlu and (
         type(moe_impl) is not FP4Moe
         or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4
         or routing_config["routing_method_type"] != RoutingMethodType.TopK
         or num_tokens > 128
     ):
+        routing = routing_config["routing_method_type"].name
         # GeGlu is only supported for FP4Moe FP4_NVFP4_NVFP4 and TopK routing
         pytest.skip(
-            f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}"
+            f"Incompatible: {moe_impl.name} {gated_act_type.name} {routing=} {num_tokens=}"
         )
     elif gated_act_type == GatedActType.SwiGlu and (
         hidden_size > 1024 or intermediate_size > 1024
     ):
         # Skip some tests for SwiGlu for testing speed
         pytest.skip(
-            f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}"
+            f"Skip for testing speed: {gated_act_type.name} {hidden_size=} {intermediate_size=}"
         )
     if type(moe_impl) not in routing_config["compatible_moe_impls"]:
-        pytest.skip(
-            f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}"
-        )
+        routing = routing_config["routing_method_type"].name
+        pytest.skip(f"Incompatible routing: {moe_impl.name} {routing=}")
     if type(moe_impl) not in weight_processing["compatible_moe_impls"]:
+        layout = weight_processing["layout"].name
+        use_shuffled_weight = weight_processing["use_shuffled_weight"]
         pytest.skip(
-            f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}"
+            f"Incompatible weight format: {moe_impl.name} {use_shuffled_weight=} {layout=}"
         )

     # TODO(jimmzhou): enable MxFP4xBf16 on SM103