@@ -222,54 +222,35 @@ __global__ void scan(
222222 }
223223}
224224
// Launches one instantiation of the scan kernel.
//   TupleT   - element tuple type (AlignedTuple or UnalignedTuple), selects
//              vectorized vs. scalar loads/stores.
//   backward - compile-time flag: when true the kernel also consumes `output`
//              and produces `gateGradOut` (backward pass).
// Dynamic shared memory is kNWarpsPerBlock * sizeof(weight_t) * 2 bytes
// (two weight_t values of cross-warp carry state per warp).
// NOTE: there must be NO space between the macro name and '(' — otherwise the
// preprocessor parses this as an object-like macro and every use breaks.
#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
    scan<TupleT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, backward><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
        reinterpret_cast<const TupleT *>(gates.data_ptr<torch_weight_t>()), \
        reinterpret_cast<const TupleT *>(tokens.data_ptr<torch_weight_t>()), \
        reinterpret_cast<TupleT *>(out.data_ptr<torch_weight_t>()), \
        reinterpret_cast<const TupleT *>(output), \
        reinterpret_cast<TupleT *>(gateGradOut), \
        batch_stride, dim_stride, reverse \
    );

// Top-level dispatcher for the scan kernel.
//
// Picks one of four instantiations along two axes:
//   * Tuple layout: when kNStepsPerThread == 4 and every pointer the kernel
//     touches is 16-byte aligned, use AlignedTuple (vectorized 16-byte
//     accesses); otherwise fall back to UnalignedTuple. A null `output` /
//     `gateGradOut` passes the `% 16 == 0` test trivially, so the forward
//     pass is not excluded from the aligned path.
//   * Direction: `backward` is a template (compile-time) parameter of the
//     kernel, so we branch at runtime on `output != nullptr` to select the
//     matching compile-time constant — the backward pass is identified by a
//     non-null `output` tensor.
//
// The cast to `long` for the alignment check assumes pointers fit in long
// (true on LP64; on Windows/LLP64 uintptr_t would be required) —
// NOTE(review): confirm target platforms.
#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
    using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
    using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
    if (kNStepsPerThread == 4 && \
        ((long) gates.data_ptr()) % 16 == 0 && \
        ((long) tokens.data_ptr()) % 16 == 0 && \
        ((long) out.data_ptr()) % 16 == 0 && \
        ((long) output) % 16 == 0 && \
        ((long) gateGradOut) % 16 == 0) { \
        if (output) { \
            DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
        } else { \
            DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
        } \
    } else { \
        if (output) { \
            DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
        } else { \
            DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
        } \
    }
275256
0 commit comments