
Commit 4c9fb49

committed
Commit more files to increase the supported topK and number of experts in DeepSeek routing for Nemotron
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
1 parent e610194 commit 4c9fb49

File tree

2 files changed: +60 -50 lines

include/flashinfer/trtllm/fused_moe/DevKernel.h

Lines changed: 57 additions & 48 deletions
@@ -169,56 +169,65 @@ namespace moe::dev {
     FLASHINFER_WARN("Unsupported dtypeExpW"); \
   }
 
-#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
-                                                smemSize, stream, extraFlag, numExperts) \
-  if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
-      data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \
-                 numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
-             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \
-                 kernel, numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
-             data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \
-                 kernel, numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
-             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_TILEN(data, coopLaunch, \
-                 LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \
-                 numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
-             data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \
-                 kernel, numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
-             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_TILEN(data, coopLaunch, \
-                 LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \
-                 numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
-             data.mDtypeExpW == tg::Dtype::Fp32) { \
-    LAUNCH_TILEN(data, coopLaunch, \
-                 LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \
-                 numBlocks, numThreads, smemSize, stream); \
-  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
-             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
-    LAUNCH_TILEN(data, coopLaunch, \
-                 LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \
-                 kernel, numBlocks, numThreads, smemSize, stream); \
-  } else { \
-    FLASHINFER_WARN("Unsupported dtypeExpW"); \
+#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
+                                                smemSize, stream, extraFlag, numExperts, \
+                                                numTopExperts) \
+  if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
+      data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_TILEN(data, coopLaunch, \
+                 LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \
+                 numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_TILEN(data, coopLaunch, \
+                 LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
+                 kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_TILEN(data, coopLaunch, \
+                 LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \
+                 kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_TILEN( \
+        data, coopLaunch, \
+        LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
+        kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
+             data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_TILEN(data, coopLaunch, \
+                 LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \
+                 kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_TILEN( \
+        data, coopLaunch, \
+        LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
+        kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Fp32) { \
+    LAUNCH_TILEN( \
+        data, coopLaunch, \
+        LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \
+        kernel, numBlocks, numThreads, smemSize, stream); \
+  } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
+             data.mDtypeExpW == tg::Dtype::Bfloat16) { \
+    LAUNCH_TILEN(data, coopLaunch, \
+                 LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \
+                            numTopExperts, extraFlag), \
+                 kernel, numBlocks, numThreads, smemSize, stream); \
+  } else { \
+    FLASHINFER_WARN("Unsupported dtypeExpW"); \
   }
 
-#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
-                                     stream, extraFlag, numExperts) \
-  if (extraFlag) { \
-    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
-                                            smemSize, stream, true, numExperts); \
-  } else { \
-    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
-                                            smemSize, stream, false, numExperts); \
+#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
+                                     stream, extraFlag, numExperts, numTopExperts) \
+  if (extraFlag) { \
+    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
+                                            smemSize, stream, true, numExperts, numTopExperts); \
+  } else { \
+    LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \
+                                            smemSize, stream, false, numExperts, numTopExperts); \
   }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
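Note on the change above: the macro turns a runtime check of the three dtype fields into one LAUNCH_TILEN call per dtype combination, so numExperts, and with this commit numTopExperts, are passed through as compile-time template arguments of the routing kernel. Below is a minimal standalone C++ sketch of that runtime-to-compile-time dispatch pattern; the enum, struct, function names, and the numeric bounds are illustrative stand-ins, not the FlashInfer API or its real limits.

#include <cstdio>

// Illustrative stand-ins for the runtime-typed routing inputs.
enum class Dtype { Fp32, Bfloat16 };

struct RoutingData {
  Dtype dtypeScore;
  int numExperts;
  int numTopExperts;
};

// Stand-in for the kernel: the expert and top-K counts are compile-time bounds.
template <typename ScoreT, int MaxNumExperts, int MaxNumTopExperts>
void routingKernel(const RoutingData&) {
  std::printf("instantiation covers numExperts<=%d, topK<=%d\n", MaxNumExperts, MaxNumTopExperts);
}

// Runtime dispatch: pick the instantiation whose compile-time bounds cover the
// request, analogous to what LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG does for
// dtypes and expert counts.
void launchRouting(const RoutingData& d) {
  if (d.dtypeScore == Dtype::Fp32) {
    if (d.numExperts <= 256 && d.numTopExperts <= 8) {
      routingKernel<float, 256, 8>(d);   // bounds are illustrative, not the real limits
    } else {
      routingKernel<float, 384, 16>(d);  // wider instantiation for larger topK / expert counts
    }
  } else {
    std::printf("Unsupported dtype\n");  // mirrors the FLASHINFER_WARN fallback branch
  }
}

int main() { launchRouting({Dtype::Fp32, 320, 10}); }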

include/flashinfer/trtllm/fused_moe/RoutingKernel.h

Lines changed: 3 additions & 2 deletions
@@ -176,14 +176,15 @@ struct Data : public DataBase {
   bool mUseRoutingSoftmax;
 };
 
-template <typename InputT_, typename BiasT_, typename OutputT_, int MaxNumExperts_, bool UseGroups_,
-          bool isPow2_, bool UsePdl_>
+template <typename InputT_, typename BiasT_, typename OutputT_, int MaxNumExperts_,
+          int MaxNumTopExperts_, bool UseGroups_, bool isPow2_, bool UsePdl_>
 struct KernelParams : public KernelParamsBase<InputT_, OutputT_, MaxNumExperts_, isPow2_, UsePdl_> {
   using InputT = InputT_;
   using BiasT = BiasT_;
   using OutputT = OutputT_;
 
   static constexpr bool UseGroups = UseGroups_;
+  static constexpr int MaxNumTopExperts = MaxNumTopExperts_;
 
   PackedScoreIdx<OutputT>* mPtrTopKPacked = nullptr;
 
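Note on the change above: carrying MaxNumTopExperts_ on KernelParams surfaces the top-K bound as a constexpr member, so a kernel can size fixed per-token buffers from it at compile time rather than from a hard-coded constant. The following is a simplified C++ sketch of that idea; the struct names and the numeric values are hypothetical and not the actual FlashInfer types.

#include <array>

// Simplified stand-in for KernelParams: the added template parameter is
// exposed as a compile-time constant, as in the diff above.
template <typename OutputT_, int MaxNumExperts_, int MaxNumTopExperts_>
struct KernelParamsSketch {
  using OutputT = OutputT_;
  static constexpr int MaxNumExperts = MaxNumExperts_;
  static constexpr int MaxNumTopExperts = MaxNumTopExperts_;
};

// Example consumer: fixed-size per-token top-K scratch, sized by the bound.
template <typename Params>
struct TopKScratch {
  std::array<typename Params::OutputT, Params::MaxNumTopExperts> weights{};
  std::array<int, Params::MaxNumTopExperts> indices{};
};

int main() {
  // Hypothetical bounds chosen for illustration only.
  using P = KernelParamsSketch<float, /*MaxNumExperts=*/384, /*MaxNumTopExperts=*/16>;
  TopKScratch<P> scratch;
  static_assert(sizeof(scratch.indices) == sizeof(int) * P::MaxNumTopExperts);
  return 0;
}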