@@ -21,20 +21,20 @@ namespace operators {
21
21
22
22
// Bit flags naming which optional inputs a GroupNorm kernel consumes.
// They are OR-combined into a compile-time case index in [0, 3] that
// selects the kernel template specialization (see UNROLL_ALL_CASES).
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
23
23
24
// Launches kernel_name<T, i> when the runtime flag combination `flags`
// equals the compile-time case `i`; otherwise expands to a dead branch.
// Relies on `T`, `grid`, `threads`, and `dev_ctx` being in scope at the
// expansion site. Uses standard C++11 `__VA_ARGS__` rather than the
// GNU-only named variadic (`args...`) so the macro is portable across
// compilers (MSVC, clang -pedantic, etc.).
#define CHECK_CASE(i, flags, kernel_name, ...)                                \
  if (i == flags) {                                                           \
    kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__);   \
  }
28
28
29
29
// Dispatches to the kernel specialization matching the runtime `flags`
// value (a combination of GroupNormKernelFlags):
//   0: no scale, no bias
//   1: has scale, no bias
//   2: no scale, has bias
//   3: has scale, has bias
// Expands to four independent `if` statements (one CHECK_CASE each), of
// which exactly one launches. Uses standard `__VA_ARGS__` instead of the
// GNU-specific named variadic (`args...`) for portability.
#define UNROLL_ALL_CASES(flags, kernel_name, ...)                             \
  CHECK_CASE(0, flags, kernel_name, __VA_ARGS__)                              \
  CHECK_CASE(1, flags, kernel_name, __VA_ARGS__)                              \
  CHECK_CASE(2, flags, kernel_name, __VA_ARGS__)                              \
  CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
38
38
39
39
template <typename T>
40
40
__device__ __inline__ void CudaAtomicAddWithWarp (T* sum, T value) {
0 commit comments