
Commit 27cd43c

CUDA: Fix integer reductions by removing +/-INF initialization (#3200)

* fix(cuda): fix integer reduction initialization

  Replace hardcoded INFINITY/-INFINITY values with type-safe template
  functions for reduction initialization. Using floating-point infinity
  values with integer types causes undefined behavior and crashes on newer
  GPU architectures like Blackwell. The new template specializations use
  appropriate numeric_limits values for integer types while preserving the
  original behavior for floating-point types.

* fix(cuda): replace limits import with cuda std equivalents

---------

Co-authored-by: ivarflakstad <[email protected]>
1 parent 3390caa commit 27cd43c
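
For context, the failure mode this commit removes is easy to reproduce in isolation: converting +/-INFINITY to an integer type is undefined behaviour in C++ ([conv.fpint]), whereas numeric_limits gives well-defined reduction identities. A minimal host-side sketch of the before/after pattern (illustrative only, not part of this commit):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Old pattern: seeding an integer reduction by casting -INFINITY. The
  // value is not representable as int64_t, so the conversion is undefined
  // behaviour; the result varies across compilers and GPU architectures.
  // int64_t bad = static_cast<int64_t>(-INFINITY);  // UB, shown for contrast

  // New pattern: well-defined identities from numeric_limits. Every real
  // input compares greater (resp. less) than these seeds, so a max- or
  // min-reduction always picks up a real element on its first comparison.
  int64_t max_identity = std::numeric_limits<int64_t>::lowest(); // -2^63
  int64_t min_identity = std::numeric_limits<int64_t>::max();    // 2^63 - 1
  std::printf("max identity = %lld, min identity = %lld\n",
              (long long)max_identity, (long long)min_identity);
  return 0;
}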

File tree

1 file changed: +69 −9 lines


candle-kernels/src/reduce.cu

Lines changed: 69 additions & 9 deletions
@@ -1,10 +1,63 @@
 #include "cuda_utils.cuh"
 #include <cmath>
 #include <stdint.h>
+#include <cuda/std/limits>

 #define WARP_SIZE 32
 const int BLOCK_SIZE = 1024;

+// Helpers to initialize reduction identities for both floating-point and
+// integer types. For floats we keep using +/-INFINITY, while for integers
+// we use well-defined numeric_limits values instead of relying on casting
+// +/-INFINITY to an integer type (which is undefined behaviour and has been
+// observed to break on newer GPU architectures such as Blackwell).
+template <typename T>
+__device__ __forceinline__ T reduce_init_lowest() {
+  // Default implementation is used for floating-point types (__half,
+  // __nv_bfloat16, float, double). The conversion from -INFINITY (double)
+  // to these types is well-defined and produces -inf.
+  return -INFINITY;
+}
+
+template <typename T>
+__device__ __forceinline__ T reduce_init_highest() {
+  // Default implementation is used for floating-point types (__half,
+  // __nv_bfloat16, float, double). The conversion from INFINITY (double)
+  // to these types is well-defined and produces +inf.
+  return INFINITY;
+}
+
+// Integer specializations – use numeric_limits instead of +/-INFINITY.
+template <>
+__device__ __forceinline__ int64_t reduce_init_lowest<int64_t>() {
+  return ::cuda::std::numeric_limits<int64_t>::lowest();
+}
+
+template <>
+__device__ __forceinline__ uint32_t reduce_init_lowest<uint32_t>() {
+  return ::cuda::std::numeric_limits<uint32_t>::lowest();
+}
+
+template <>
+__device__ __forceinline__ uint8_t reduce_init_lowest<uint8_t>() {
+  return ::cuda::std::numeric_limits<uint8_t>::lowest();
+}
+
+template <>
+__device__ __forceinline__ int64_t reduce_init_highest<int64_t>() {
+  return ::cuda::std::numeric_limits<int64_t>::max();
+}
+
+template <>
+__device__ __forceinline__ uint32_t reduce_init_highest<uint32_t>() {
+  return ::cuda::std::numeric_limits<uint32_t>::max();
+}
+
+template <>
+__device__ __forceinline__ uint8_t reduce_init_highest<uint8_t>() {
+  return ::cuda::std::numeric_limits<uint8_t>::max();
+}
+
 // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
 // but also expect a f32 output so that this can be used for normalization e.g.
 // in softmax.
@@ -102,29 +155,29 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,

   if (alpha == nullptr && beta == nullptr) {
     for (int col = tid; col < ncols; col += block_size) {
-      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
       dst[row*ncols + col] = static_cast<T>(lhs);
     }
   }
   else if (alpha == nullptr && beta != nullptr) {
     for (int col = tid; col < ncols; col += block_size) {
       float b = static_cast<float>(beta[col]);
-      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
       dst[row*ncols + col] = static_cast<T>(lhs + b);
     }
   }
   else if (alpha != nullptr && beta == nullptr) {
     for (int col = tid; col < ncols; col += block_size) {
       float a = static_cast<float>(alpha[col]);
-      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
       dst[row*ncols + col] = static_cast<T>(lhs * a);
     }
   }
   else {
     for (int col = tid; col < ncols; col += block_size) {
       float a = static_cast<float>(alpha[col]);
       float b = static_cast<float>(beta[col]);
-      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
       dst[row*ncols + col] = static_cast<T>(lhs * a + b);
     }
   }
@@ -301,7 +354,9 @@ fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;

-  shr[tid] = -INFINITY;
+  // Initialize with the lowest representable value for T so that the first
+  // comparison in the reduction always picks a real element.
+  shr[tid] = reduce_init_lowest<T>();
   // Elements summed in this block range from dst_id * el_to_sum_per_block
   // to (dst_id + 1) * el_to_sum_per_block.
   size_t start_idx = dst_id * el_to_sum_per_block;
@@ -339,7 +394,9 @@ fast_min(const size_t src_numel, const size_t el_to_sum_per_block,
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;

-  shr[tid] = INFINITY;
+  // Initialize with the highest representable value for T so that the first
+  // comparison in the reduction always picks a real element.
+  shr[tid] = reduce_init_highest<T>();
   // Elements summed in this block range from dst_id * el_to_sum_per_block
   // to (dst_id + 1) * el_to_sum_per_block.
   size_t start_idx = dst_id * el_to_sum_per_block;
@@ -378,8 +435,9 @@ fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block,
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;

-  // Not sure how that works on uint32_t and uint8_t but it seems to do ok.
-  shr[tid] = INFINITY;
+  // For floating types this uses +inf; for integer types we use the largest
+  // representable value instead of casting INFINITY to an integer.
+  shr[tid] = reduce_init_highest<T>();
   shr_index[tid] = 0xFFFFFFFF;
   bool not_set = true;
   // Elements summed in this block range from dst_id * el_to_sum_per_block
@@ -427,7 +485,9 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
   size_t tid = threadIdx.x;
   size_t dst_id = blockIdx.x;

-  shr[tid] = -INFINITY;
+  // For floating types this uses -inf; for integer types we use the lowest
+  // representable value instead of casting -INFINITY to an integer.
+  shr[tid] = reduce_init_lowest<T>();
   shr_index[tid] = 0xFFFFFFFF;
   bool not_set = true;
   // Elements summed in this block range from dst_id * el_to_sum_per_block
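
To make the pattern concrete: with shr seeded from reduce_init_lowest<T>(), the first comparison against any real input always replaces the seed, so the same kernel body instantiates correctly for float, __half, int64_t, uint32_t, and uint8_t. Below is a minimal block-level sketch, assuming reduce_init_lowest<T>() and BLOCK_SIZE are in scope as defined above; block_max_sketch is a hypothetical name, not the actual fast_max kernel from reduce.cu:

// Simplified block max-reduction; illustration only, not candle's fast_max.
// Assumes a single block launch (gridDim.x == 1) with blockDim.x == BLOCK_SIZE,
// a power of two.
template <typename T>
__global__ void block_max_sketch(const T *x, T *out, size_t n) {
  __shared__ T shr[BLOCK_SIZE];
  size_t tid = threadIdx.x;

  // Seed with the lowest representable T (-inf for floating-point types):
  // the first comparison against a real element always replaces the seed.
  shr[tid] = reduce_init_lowest<T>();
  for (size_t i = tid; i < n; i += blockDim.x) {
    if (x[i] > shr[tid]) shr[tid] = x[i];
  }
  __syncthreads();

  // Standard shared-memory tree reduction.
  for (size_t s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s && shr[tid + s] > shr[tid]) shr[tid] = shr[tid + s];
    __syncthreads();
  }
  if (tid == 0) *out = shr[0];
}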
