Skip to content

Commit a8aa791

Browse files
author
wopeizl
authored
Merge pull request #15453 from wopeizl/fix15313
fix pr 15313
2 parents 7f8b40f + e6a3a3a commit a8aa791

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

paddle/fluid/operators/group_norm_op.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,20 @@ namespace operators {
2121

2222
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
2323

24-
#define CHECK_CASE(i, flags, kernel_name, args...) \
25-
if (i == flags) { \
26-
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \
24+
#define CHECK_CASE(i, flags, kernel_name, ...) \
25+
if (i == flags) { \
26+
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
2727
}
2828

2929
// 0 for no scale, no bias
3030
// 1 for has scale, no bias
3131
// 2 for no scale, has bias
3232
// 3 for has scale, has bias
33-
#define UNROLL_ALL_CASES(flags, kernel_name, args...) \
34-
CHECK_CASE(0, flags, kernel_name, args) \
35-
CHECK_CASE(1, flags, kernel_name, args) \
36-
CHECK_CASE(2, flags, kernel_name, args) \
37-
CHECK_CASE(3, flags, kernel_name, args)
33+
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
34+
CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
35+
CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
36+
CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
37+
CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
3838

3939
template <typename T>
4040
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {

0 commit comments

Comments
 (0)