11 changes: 10 additions & 1 deletion csrc/trtllm_batched_gemm_runner.cu
@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cstring>
#include <vector>

#include "flashinfer/trtllm/batched_gemm/KernelRunner.h"
@@ -115,7 +116,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
}
}

FLASHINFER_CHECK(!mPassingConfigIndices.empty(), "No kernel found for the given options");
FLASHINFER_CHECK(
!mPassingConfigIndices.empty(),
"No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, "
"mUseDeepSeekFp8: %d, "
"mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
}

size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
@@ -367,6 +375,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(

return false;
};

// Sort configs by options.
std::vector<int64_t> sortedIndices = mPassingConfigIndices;
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);
42 changes: 28 additions & 14 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -393,6 +393,12 @@ void trtllm_fp8_block_scale_moe_launcher(
int32_t max_num_padded_tokens =
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
args.num_tokens, top_k, num_experts, tile_tokens_dim);
int32_t max_num_padded_tokens_gemm1 =
tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
int32_t max_num_padded_tokens_gemm2 =
tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
Comment on lines +396 to +401 (Contributor, severity: medium):

This block for calculating max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 is duplicated in trtllm_fp4_block_scale_moe_launcher (lines 774-779). In fact, the entire function trtllm_fp8_block_scale_moe_launcher is very similar to trtllm_fp4_block_scale_moe_launcher. To improve maintainability and reduce redundancy, consider refactoring the common logic into a templated helper function.
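
A minimal sketch of what such a shared helper could look like (hypothetical name, placement, and parameter types; it only wraps the two calls that appear verbatim in both launchers):

// Hypothetical helper; name, location, and parameter types are illustrative only.
struct MaxPaddedTokenCounts {
  int32_t gemm1;
  int32_t gemm2;
};

inline MaxPaddedTokenCounts getMaxPaddedTokenCounts(int32_t max_num_padded_tokens,
                                                    int32_t intermediate_size, int32_t hidden_size,
                                                    btg::Dtype dtype_elt, btg::Dtype dtype_out) {
  // Clamp the padded token count to the minimum each GEMM requires, exactly as the
  // duplicated blocks in the fp8 and fp4 launchers do today.
  MaxPaddedTokenCounts counts;
  counts.gemm1 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
      max_num_padded_tokens, intermediate_size, btg::dtypeGetNumBits(dtype_elt));
  counts.gemm2 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
      max_num_padded_tokens, hidden_size, btg::dtypeGetNumBits(dtype_out));
  return counts;
}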

Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device);
Tensor expanded_idx_to_permuted_idx =
alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device);
@@ -413,16 +419,16 @@
// dl_float8_e4m3fn, hidden_states->device);
// Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size},
// dl_float8_e4m3fn, hidden_states->device);
Tensor gemm1_output =
alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states->device);
Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8,
hidden_states->device);
Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
dl_float32, hidden_states->device);
Tensor activation_output =
alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states->device);
Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens},
dl_float32, hidden_states->device);
Tensor gemm2_output =
alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device);
Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size},
dl_uint8, hidden_states->device);
Tensor activation_output_scale = alloc_tensor(
{intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states->device);
Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
hidden_states->device);

int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
@@ -519,7 +525,8 @@ void trtllm_fp8_block_scale_moe_launcher(

// setup workspace
workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens->data);
workspace.total_max_padded_tokens = max_num_padded_tokens;
workspace.total_max_padded_tokens =
std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
workspace.ProjUpTileN = tile_tokens_dim;
workspace.routing_expert_indexes = static_cast<int*>(expert_indexes->data);
workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens->data);
@@ -764,6 +771,12 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
int32_t max_num_padded_tokens =
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
args.num_tokens, top_k, num_experts, tile_tokens_dim);
int32_t max_num_padded_tokens_gemm1 =
tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt));
int32_t max_num_padded_tokens_gemm2 =
tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount(
max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut));
Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states->device);
Tensor expanded_idx_to_permuted_idx =
alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states->device);
@@ -788,20 +801,20 @@
// Tensor gemm1_output = alloc_tensor(
// {max_num_padded_tokens, gemm1_output_hidden},
// dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, hidden_states->device);
Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, gemm1_output_hidden},
Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden},
dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_uint8,
hidden_states->device);

Optional<Tensor> gemm1_output_scale = std::nullopt;
if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) {
int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens,
int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1,
intermediate_size / sf_vec_size);
// gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states->device);
gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states->device);
}

Tensor gemm2_output =
alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device);
Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16,
hidden_states->device);

int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
@@ -958,7 +971,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(

// setup workspace
workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens->data);
workspace.total_max_padded_tokens = max_num_padded_tokens;
workspace.total_max_padded_tokens =
std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2);
workspace.ProjUpTileN = tile_tokens_dim;
workspace.routing_expert_indexes = static_cast<int*>(expert_indices->data);
workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens->data);
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -76,7 +76,7 @@ def get_available_cubin_files(
class ArtifactPath:
TRTLLM_GEN_FMHA: str = "7206d64e67f4c8949286246d6e2e07706af5d223/fmha/trtllm-gen"
TRTLLM_GEN_BMM: str = (
"e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802"
"56fea80cb22f8b2ef2a2c6a822a075fb20b36803/batched_gemm-074aec4-cc00b23"
)
TRTLLM_GEN_GEMM: str = (
"037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e"
@@ -91,7 +91,7 @@ class MetaInfoHash:
"2f605255e71d673768f5bece66dde9e2e9f4c873347bfe8fefcffbf86a3c847d"
)
TRTLLM_GEN_BMM: str = (
"c98b4ce69a39fd41556d67033c30ea814ef76b0a2fe16e798e55baf0104acc34"
"4a8ceeb356fc5339021acf884061e97e49e01da5c75dbf0f7cf4932c37a70152"
)
DEEPGEMM: str = "b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48"
TRTLLM_GEN_GEMM: str = (
16 changes: 9 additions & 7 deletions flashinfer/fused_moe/core.py
@@ -894,7 +894,9 @@ def __init__(
self.gated_act_type = gated_act_type
self.tile_tokens_dim = tile_tokens_dim

def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
def get_tile_tokens_dim(
self, num_tokens: int, top_k: int, max_tile_tokens_dim: int = 128
):
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
@@ -910,10 +912,10 @@ def get_tile_tokens_dim(self, num_tokens: int, top_k: int):
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
tile_tokens_dim = 192
# Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
return tile_tokens_dim
Comment on lines 914 to 919 (Contributor, severity: medium):

The logic for calculating tile_tokens_dim from num_tokens_per_expert is duplicated here and in flashinfer.utils.calculate_tile_tokens_dim. To avoid code duplication and improve maintainability, you could extract this common logic into a new helper function in flashinfer.utils.

For example:

# in flashinfer/utils.py
def _calculate_tile_dim_from_tokens_per_expert(num_tokens_per_expert: int, max_tile_tokens_dim: int = 128) -> int:
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    if 128 < num_tokens_per_expert < 256:
        tile_tokens_dim = 192
    tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
    return tile_tokens_dim

Then both calculate_tile_tokens_dim and get_tile_tokens_dim can call this helper function after calculating their respective num_tokens_per_expert.
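
For illustration only (this sketch is not part of the review comment), the flashinfer.utils call site could then become a thin wrapper around that helper, assuming the helper above is defined next to it in flashinfer/utils.py:

# in flashinfer/utils.py -- illustrative sketch only
def calculate_tile_tokens_dim(
    num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
) -> int:
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = num_tokens * top_k // num_experts
    # The helper applies the power-of-2 rounding, the 192 special case, and the 8..max cap.
    return _calculate_tile_dim_from_tokens_per_expert(
        num_tokens_per_expert, max_tile_tokens_dim
    )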


def get_valid_tactics(
@@ -931,7 +933,7 @@ def get_valid_tactics(
) = inputs
num_tokens = routing_logits.shape[0]
tile_tokens_dim = (
self.get_tile_tokens_dim(num_tokens, self.top_k)
self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
if self.tile_tokens_dim is None
else self.tile_tokens_dim
)
@@ -975,7 +977,7 @@ def forward(
) = inputs
num_tokens = routing_logits.shape[0]
tile_tokens_dim = (
self.get_tile_tokens_dim(num_tokens, self.top_k)
self.get_tile_tokens_dim(num_tokens, self.top_k, 128)
if self.tile_tokens_dim is None
else self.tile_tokens_dim
)
11 changes: 7 additions & 4 deletions flashinfer/utils.py
@@ -113,15 +113,18 @@ def next_positive_power_of_2(x: int) -> int:
return n + 1


def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
def calculate_tile_tokens_dim(
num_tokens: int, num_experts: int, top_k: int, max_tile_tokens_dim: int = 128
) -> int:
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = num_tokens * top_k // num_experts

# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

if num_tokens_per_expert > 128 and num_tokens_per_expert < 256:
tile_tokens_dim = 192
# Cap to 8-max_tile_tokens_dim tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), max_tile_tokens_dim)
return tile_tokens_dim


@@ -506,8 +506,19 @@ class BatchedGemmInterface {
throw std::invalid_argument("Invalid combination of options");
}

int32_t const numCtasTile =
if (batchM) {
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimX);
} else {
numCtasBatch = gemm::divUpMul(numCtasBatch, options.mClusterDimY);
}

int32_t numCtasTile =
batchM ? gemm::divUp(options.mN, options.mTileN) : gemm::divUp(options.mM, options.mTileM);
if (batchM) {
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimY);
} else {
numCtasTile = gemm::divUpMul(numCtasTile, options.mClusterDimX);
}
int32_t const numCtasInner = options.mNumSlicesForSplitK;
return std::make_tuple(numCtasBatch, numCtasTile, numCtasInner);
}