@@ -169,56 +169,65 @@ namespace moe::dev {
169169 FLASHINFER_WARN (" Unsupported dtypeExpW" ); \
170170 }
171171
// LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG -- runtime-dtype dispatch for a DeepSeek routing kernel.
//
// This hunk shows the macro before ('-' lines) and after ('+' lines) the change. The change
// appends a `numTopExperts` macro parameter and forwards it to LAUNCH_ESC between `numExperts`
// and `extraFlag`; the branch conditions themselves are unchanged.
//
// Behaviour visible in the body:
//   * Compares data.mDtypeScore, data.mDtypeBias and data.mDtypeExpW against tg::Dtype::Fp32 and
//     tg::Dtype::Bfloat16, giving eight (score, bias, expW) combinations.
//   * For each combination it invokes LAUNCH_TILEN, passing a LAUNCH_ESC(...) pack whose first
//     three entries are `float` / `__nv_bfloat16` in the same (score, bias, expW) order, followed
//     by numExperts, (new) numTopExperts, and extraFlag. The remaining arguments (kernel,
//     numBlocks, numThreads, smemSize, stream) are forwarded unchanged.
//   * Any other combination only calls FLASHINFER_WARN; no LAUNCH_TILEN call is made on that path.
//
// NOTE(review): the fallback message names only "dtypeExpW", yet the else-branch is reached for
// any unsupported score/bias/expW combination -- confirm whether the text should be broader.
// NOTE(review): LAUNCH_ESC / LAUNCH_TILEN are defined outside this view; presumably LAUNCH_ESC
// protects the comma-separated template arguments from macro argument splitting -- verify.
// NOTE(review): the expansion is a bare if/else-if chain (not wrapped in do { } while (0)), so it
// should be used as a standalone statement to avoid dangling-else surprises.
172- #define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
173- smemSize, stream, extraFlag, numExperts) \
174- if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
175- data.mDtypeExpW == tg::Dtype::Fp32) { \
176- LAUNCH_TILEN (data, coopLaunch, LAUNCH_ESC (float , float , float , numExperts, extraFlag), kernel, \
177- numBlocks, numThreads, smemSize, stream); \
178- } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
179- data.mDtypeExpW == tg::Dtype::Bfloat16) { \
180- LAUNCH_TILEN (data, coopLaunch, LAUNCH_ESC (float , float , __nv_bfloat16, numExperts, extraFlag), \
181- kernel, numBlocks, numThreads, smemSize, stream); \
182- } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
183- data.mDtypeExpW == tg::Dtype::Fp32) { \
184- LAUNCH_TILEN (data, coopLaunch, LAUNCH_ESC (float , __nv_bfloat16, float , numExperts, extraFlag), \
185- kernel, numBlocks, numThreads, smemSize, stream); \
186- } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
187- data.mDtypeExpW == tg::Dtype::Bfloat16) { \
188- LAUNCH_TILEN (data, coopLaunch, \
189- LAUNCH_ESC (float , __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \
190- numBlocks, numThreads, smemSize, stream); \
191- } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
192- data.mDtypeExpW == tg::Dtype::Fp32) { \
193- LAUNCH_TILEN (data, coopLaunch, LAUNCH_ESC (__nv_bfloat16, float , float , numExperts, extraFlag), \
194- kernel, numBlocks, numThreads, smemSize, stream); \
195- } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
196- data.mDtypeExpW == tg::Dtype::Bfloat16) { \
197- LAUNCH_TILEN (data, coopLaunch, \
198- LAUNCH_ESC (__nv_bfloat16, float , __nv_bfloat16, numExperts, extraFlag), kernel, \
199- numBlocks, numThreads, smemSize, stream); \
200- } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
201- data.mDtypeExpW == tg::Dtype::Fp32) { \
202- LAUNCH_TILEN (data, coopLaunch, \
203- LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, float , numExperts, extraFlag), kernel, \
204- numBlocks, numThreads, smemSize, stream); \
205- } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
206- data.mDtypeExpW == tg::Dtype::Bfloat16) { \
207- LAUNCH_TILEN (data, coopLaunch, \
208- LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \
209- kernel, numBlocks, numThreads, smemSize, stream); \
210- } else { \
211- FLASHINFER_WARN (" Unsupported dtypeExpW" ); \
172+ #define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
173+ smemSize, stream, extraFlag, numExperts, \
174+ numTopExperts) \
175+ if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
176+ data.mDtypeExpW == tg::Dtype::Fp32) { \
177+ LAUNCH_TILEN (data, coopLaunch, \
178+ LAUNCH_ESC (float , float , float , numExperts, numTopExperts, extraFlag), kernel, \
179+ numBlocks, numThreads, smemSize, stream); \
180+ } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \
181+ data.mDtypeExpW == tg::Dtype::Bfloat16) { \
182+ LAUNCH_TILEN (data, coopLaunch, \
183+ LAUNCH_ESC (float , float , __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
184+ kernel, numBlocks, numThreads, smemSize, stream); \
185+ } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
186+ data.mDtypeExpW == tg::Dtype::Fp32) { \
187+ LAUNCH_TILEN (data, coopLaunch, \
188+ LAUNCH_ESC (float , __nv_bfloat16, float , numExperts, numTopExperts, extraFlag), \
189+ kernel, numBlocks, numThreads, smemSize, stream); \
190+ } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
191+ data.mDtypeExpW == tg::Dtype::Bfloat16) { \
192+ LAUNCH_TILEN ( \
193+ data, coopLaunch, \
194+ LAUNCH_ESC (float , __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
195+ kernel, numBlocks, numThreads, smemSize, stream); \
196+ } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
197+ data.mDtypeExpW == tg::Dtype::Fp32) { \
198+ LAUNCH_TILEN (data, coopLaunch, \
199+ LAUNCH_ESC (__nv_bfloat16, float , float , numExperts, numTopExperts, extraFlag), \
200+ kernel, numBlocks, numThreads, smemSize, stream); \
201+ } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \
202+ data.mDtypeExpW == tg::Dtype::Bfloat16) { \
203+ LAUNCH_TILEN ( \
204+ data, coopLaunch, \
205+ LAUNCH_ESC (__nv_bfloat16, float , __nv_bfloat16, numExperts, numTopExperts, extraFlag), \
206+ kernel, numBlocks, numThreads, smemSize, stream); \
207+ } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
208+ data.mDtypeExpW == tg::Dtype::Fp32) { \
209+ LAUNCH_TILEN ( \
210+ data, coopLaunch, \
211+ LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, float , numExperts, numTopExperts, extraFlag), \
212+ kernel, numBlocks, numThreads, smemSize, stream); \
213+ } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \
214+ data.mDtypeExpW == tg::Dtype::Bfloat16) { \
215+ LAUNCH_TILEN (data, coopLaunch, \
216+ LAUNCH_ESC (__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \
217+ numTopExperts, extraFlag), \
218+ kernel, numBlocks, numThreads, smemSize, stream); \
219+ } else { \
220+ FLASHINFER_WARN (" Unsupported dtypeExpW" ); \
212221 }
213222
// LAUNCH_ROUTING_DEEPSEEK_IMPL -- turns the runtime value `extraFlag` into a literal.
//
// This hunk shows the macro before ('-' lines) and after ('+' lines) the change. The change
// appends a `numTopExperts` macro parameter and passes it through unchanged as the last argument
// of LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG.
//
// Behaviour visible in the body: `extraFlag` is tested once at run time, and each branch invokes
// LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG with the literal token `true` or `false` in the
// extraFlag position. All other arguments are forwarded as given.
//
// NOTE(review): presumably the literal is needed because the inner macro uses the flag as a
// compile-time (template) argument via LAUNCH_ESC -- confirm against the LAUNCH_ESC/LAUNCH_TILEN
// definitions, which are outside this view.
// NOTE(review): the expansion is a bare if/else (not wrapped in do { } while (0)), so it should
// be used as a standalone statement.
214- #define LAUNCH_ROUTING_DEEPSEEK_IMPL (data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
215- stream, extraFlag, numExperts) \
216- if (extraFlag) { \
217- LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
218- smemSize, stream, true , numExperts); \
219- } else { \
220- LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
221- smemSize, stream, false , numExperts); \
223+ #define LAUNCH_ROUTING_DEEPSEEK_IMPL (data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
224+ stream, extraFlag, numExperts, numTopExperts) \
225+ if (extraFlag) { \
226+ LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
227+ smemSize, stream, true , numExperts, numTopExperts); \
228+ } else { \
229+ LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG (data, coopLaunch, kernel, numBlocks, numThreads, \
230+ smemSize, stream, false , numExperts, numTopExperts); \
222231 }
223232
224233// //////////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments