Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
453 changes: 272 additions & 181 deletions csrc/trtllm_fused_moe_routing_deepseek.cu

Large diffs are not rendered by default.

227 changes: 143 additions & 84 deletions csrc/trtllm_fused_moe_routing_llama4.cu

Large diffs are not rendered by default.

329 changes: 251 additions & 78 deletions csrc/trtllm_fused_moe_routing_renormalize.cu

Large diffs are not rendered by default.

84 changes: 53 additions & 31 deletions include/flashinfer/trtllm/fused_moe/DevKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,39 +112,61 @@ namespace moe::dev {
FLASHINFER_WARN("Unsupported pair"); \
}

#define LAUNCH_ROUTING(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
// Launches the Llama4 routing `kernel` through LAUNCH_PDL, choosing the type list handed to
// LAUNCH_PDL from the runtime value of `data.mDtypeExpW`:
//   tg::Dtype::Fp32      -> LAUNCH_ESC(float, float, 128)
//   tg::Dtype::Bfloat16  -> LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128)
//   anything else        -> logs "Unsupported dtypeExpW"; LAUNCH_PDL is not invoked.
// coopLaunch, kernel, numBlocks, numThreads, smemSize and stream are forwarded to LAUNCH_PDL
// unchanged. The third type-list entry is hard-coded to 128 for Llama4; what that slot means is
// decided by the kernel's template parameters, which are not visible here.
//
// `data` is parenthesized wherever this macro dereferences it, so an expression argument
// (for example `*ptr`) binds as intended instead of being re-associated by operator precedence.
//
// The expansion is a bare if / else-if / else chain (no do-while wrapper): use it as a complete
// statement inside a braced scope to avoid capturing a following `else`.
//
// NOTE(review): the neighbouring launch macro in this header reports unsupported inputs with
// FLASHINFER_WARN, while this fallback uses TLLM_LOG_ERROR -- confirm that is intended.
#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
  if ((data).mDtypeExpW == tg::Dtype::Fp32) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), kernel, \
               numBlocks, numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16) { \
    LAUNCH_PDL(data, coopLaunch, \
               LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \
               numBlocks, numThreads, smemSize, stream); \
  } else { \
    TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
  }

#define LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag, forceFloatInput) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, true), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Fp32) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} else { \
FLASHINFER_WARN("Unsupported dtypeExpW"); \
// Launches a routing `kernel` through LAUNCH_PDL, choosing the type list from the runtime values
// of `data.mDtypeExpW`, `extraFlag` and `forceFloatInput`. The first matching row wins:
//
//   mDtypeExpW  extraFlag  forceFloatInput   type list passed via LAUNCH_ESC
//   ----------  ---------  ---------------   ------------------------------------------------
//   Fp32        true       (ignored)         float,         float,         numExperts, true
//   Fp32        false      (ignored)         float,         float,         numExperts, false
//   Bfloat16    true       true              float,         __nv_bfloat16, numExperts, true
//   Bfloat16    true       false             __nv_bfloat16, __nv_bfloat16, numExperts, true
//   Bfloat16    false      true              float,         __nv_bfloat16, numExperts, false
//   Bfloat16    false      false             __nv_bfloat16, __nv_bfloat16, numExperts, false
//   other       -          -                 logs "Unsupported dtypeExpW"; nothing is launched
//
// coopLaunch, kernel, numBlocks, numThreads, smemSize and stream are forwarded to LAUNCH_PDL
// unchanged. `numExperts` is pasted verbatim into the type list (presumably it must be a
// compile-time constant usable as a template argument -- TODO confirm against LAUNCH_PDL).
//
// `data`, `extraFlag` and `forceFloatInput` are parenthesized inside the conditions. Without the
// parentheses an argument such as `a || b` would expand to `dtype == X && a || b` and select the
// wrong branch because && binds tighter than ||.
//
// The expansion is a bare if / else-if / else chain (no do-while wrapper): use it as a complete
// statement inside a braced scope to avoid capturing a following `else`.
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, \
                                                          numThreads, smemSize, stream, extraFlag, \
                                                          forceFloatInput, numExperts) \
  if ((data).mDtypeExpW == tg::Dtype::Fp32 && (extraFlag)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \
               numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Fp32) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \
               numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16 && (extraFlag) && (forceFloatInput)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, \
               numBlocks, numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16 && (extraFlag)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \
               kernel, numBlocks, numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16 && (forceFloatInput)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, \
               numBlocks, numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
               kernel, numBlocks, numThreads, smemSize, stream); \
  } else { \
    TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launches a routing `kernel` through LAUNCH_PDL, choosing the type list from the runtime values
// of `data.mDtypeExpW` and `extraFlag1`. The first matching row wins:
//
//   mDtypeExpW  extraFlag1   type list passed via LAUNCH_ESC
//   ----------  ----------   ------------------------------------------------
//   Fp32        true         float,         float,         numExperts, true
//   Fp32        false        float,         float,         numExperts, false
//   Bfloat16    true         __nv_bfloat16, __nv_bfloat16, numExperts, true
//   Bfloat16    false        __nv_bfloat16, __nv_bfloat16, numExperts, false
//   other       -            logs "Unsupported dtypeExpW"; nothing is launched
//
// coopLaunch, kernel, numBlocks, numThreads, smemSize and stream are forwarded to LAUNCH_PDL
// unchanged. `numExperts` is pasted verbatim into the type list (presumably it must be a
// compile-time constant usable as a template argument -- TODO confirm against LAUNCH_PDL).
//
// `data` and `extraFlag1` are parenthesized inside the conditions. Without the parentheses an
// argument such as `a || b` would expand to `dtype == X && a || b` and select the wrong branch
// because && binds tighter than ||.
//
// The expansion is a bare if / else-if / else chain (no do-while wrapper): use it as a complete
// statement inside a braced scope to avoid capturing a following `else`.
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
                                        stream, extraFlag1, numExperts) \
  if ((data).mDtypeExpW == tg::Dtype::Fp32 && (extraFlag1)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \
               numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Fp32) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \
               numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16 && (extraFlag1)) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \
               kernel, numBlocks, numThreads, smemSize, stream); \
  } else if ((data).mDtypeExpW == tg::Dtype::Bfloat16) { \
    LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \
               kernel, numBlocks, numThreads, smemSize, stream); \
  } else { \
    TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
  }

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading