Skip to content

Commit f47c2f1

Browse files
committed
macro cleanup
1 parent 4e57346 commit f47c2f1

File tree

1 file changed

+22
-41
lines changed

1 file changed

+22
-41
lines changed

accelerated_scan/warp.cuh

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -222,54 +222,35 @@ __global__ void scan(
222222
}
223223
}
224224

225+
#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
226+
scan<TupleT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, backward><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
227+
reinterpret_cast<const TupleT *>(gates.data_ptr<torch_weight_t>()), \
228+
reinterpret_cast<const TupleT *>(tokens.data_ptr<torch_weight_t>()), \
229+
reinterpret_cast<TupleT *>(out.data_ptr<torch_weight_t>()), \
230+
reinterpret_cast<const TupleT *>(output), \
231+
reinterpret_cast<TupleT *>(gateGradOut), \
232+
batch_stride, dim_stride, reverse \
233+
);
234+
225235
#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
226236
using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
227237
using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
228-
if (!output) { \
229-
if (kNStepsPerThread == 4 && \
230-
((long)gates.data_ptr()) % 16 == 0 && \
231-
((long)tokens.data_ptr()) % 16 == 0 && \
232-
((long)out.data_ptr()) % 16 == 0) { \
233-
scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, false><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
234-
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
235-
reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
236-
reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
237-
nullptr, nullptr, \
238-
batch_stride, dim_stride, reverse \
239-
); \
238+
if (kNStepsPerThread == 4 && \
239+
((long)gates.data_ptr()) % 16 == 0 && \
240+
((long)tokens.data_ptr()) % 16 == 0 && \
241+
((long)out.data_ptr()) % 16 == 0 && \
242+
((long)output) % 16 == 0 && \
243+
((long)gateGradOut) % 16 == 0) { \
244+
if (output) { \
245+
DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
240246
} else { \
241-
scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, false><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
242-
reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
243-
reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
244-
reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
245-
nullptr, nullptr, \
246-
batch_stride, dim_stride, reverse \
247-
); \
247+
DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
248248
} \
249249
} else { \
250-
if (kNStepsPerThread == 4 && \
251-
((long)gates.data_ptr()) % 16 == 0 && \
252-
((long)tokens.data_ptr()) % 16 == 0 && \
253-
((long)out.data_ptr()) % 16 == 0 && \
254-
((long)output) % 16 == 0 && \
255-
((long)gateGradOut) % 16 == 0) { \
256-
scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, true><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
257-
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
258-
reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
259-
reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
260-
reinterpret_cast<const AlignedT *>(output), \
261-
reinterpret_cast<AlignedT *>(gateGradOut), \
262-
batch_stride, dim_stride, reverse \
263-
); \
250+
if (output) { \
251+
DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
264252
} else { \
265-
scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, true><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
266-
reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
267-
reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
268-
reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
269-
reinterpret_cast<const UnalignedT *>(output), \
270-
reinterpret_cast<UnalignedT *>(gateGradOut), \
271-
batch_stride, dim_stride, reverse \
272-
); \
253+
DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
273254
} \
274255
}
275256

0 commit comments

Comments
 (0)