Commit 066f6cf

fix pot. int overflows
1 parent 1120d94 commit 066f6cf

3 files changed: 24 additions & 24 deletions

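Why the casts: each kernel processes one tensor row per block and offsets its data pointers by the product of a 32-bit row index (blockIdx.x, or an int derived from it) and a 32-bit ncols/nclasses. That product is evaluated in 32-bit arithmetic, so once a row offset no longer fits in 32 bits it wraps (and for the int-typed indices the overflow is undefined behavior) before it is added to the pointer. Casting one operand to int64_t promotes the whole multiplication to 64 bits. A minimal host-side sketch of the wrap, with illustrative values that are not part of the commit:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const unsigned int row = 200000; // stands in for blockIdx.x (a 32-bit unsigned index)
        const int nclasses     = 30000;  // stands in for nclasses/ncols

        // 32-bit product: 200000*30000 = 6e9 wraps modulo 2^32 to 1705032704,
        // so a pointer advanced by it would land on the wrong row.
        const unsigned int wrapped = row*nclasses;

        // Promoting one operand to int64_t makes the whole product 64-bit.
        const int64_t promoted = int64_t(row)*nclasses;

        printf("32-bit offset: %u\n", wrapped);                 // 1705032704
        printf("64-bit offset: %lld\n", (long long) promoted);  // 6000000000
        return 0;
    }

The same int64_t(...) cast is applied to every pointer-offset line in the three files below.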

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 11 additions & 11 deletions
@@ -10,8 +10,8 @@ static __global__ void cross_entropy_loss_f32(
         const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
     extern __shared__ float tmp[];

-    logits += blockIdx.x*nclasses;
-    labels += blockIdx.x*nclasses;
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;

     // Find maximum for softmax:
     float max_logit = -INFINITY;
@@ -55,9 +55,9 @@ static __global__ void cross_entropy_loss_back_f32(
         float * __restrict__ dst, const int nclasses) {
     extern __shared__ float tmp[];

-    logits += blockIdx.x*nclasses;
-    labels += blockIdx.x*nclasses;
-    dst += blockIdx.x*nclasses;
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
+    dst += int64_t(blockIdx.x)*nclasses;

     float maxval = -INFINITY;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
@@ -115,10 +115,10 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *

     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int nbytes_shared = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);

-    const int id = ggml_cuda_get_device();
-    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

@@ -169,10 +169,10 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten

     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int nbytes_shared = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);

-    const int id = ggml_cuda_get_device();
-    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

     if (nbytes_shared <= smpbo) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
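The two host-side hunks above also widen nbytes_shared and smpbo from int to size_t. ne00 (the row length) is a 64-bit extent in ggml, so ne00*sizeof(float) is computed in 64 bits either way; storing it in an int could truncate the byte count before the nbytes_shared <= smpbo check against the device shared-memory limit. A rough sketch of that truncation, again with illustrative values rather than anything from the commit:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = 700000000; // illustrative row length: ~0.7e9 floats = 2.8 GB

        const int    as_int    = (int)(ne00*sizeof(float)); // 2.8e9 does not fit in a 32-bit int
        const size_t as_size_t = ne00*sizeof(float);        // kept exact in 64 bits

        printf("as int:    %d\n",  as_int);     // implementation-defined, typically negative
        printf("as size_t: %zu\n", as_size_t);  // 2800000000
        return 0;
    }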

ggml/src/ggml-cuda/norm.cu

Lines changed: 7 additions & 7 deletions
@@ -5,8 +5,8 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

-    x += row*ncols;
-    dst += row*ncols;
+    x += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;

     float2 mean_var = make_float2(0.0f, 0.0f);

@@ -101,8 +101,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

-    x += row*ncols;
-    dst += row*ncols;
+    x += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;

     float tmp = 0.0f; // partial sum for thread in warp

@@ -140,9 +140,9 @@ static __global__ void rms_norm_back_f32(
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

-    grad += row*ncols;
-    xf += row*ncols;
-    dst += row*ncols;
+    grad += int64_t(row)*ncols;
+    xf += int64_t(row)*ncols;
+    dst += int64_t(row)*ncols;

     float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
     float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs

ggml/src/ggml-cuda/softmax.cu

Lines changed: 6 additions & 6 deletions
@@ -23,9 +23,9 @@ static __global__ void soft_max_f32(
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

-    x += rowx*ncols;
-    mask += rowy*ncols * (mask != nullptr);
-    dst += rowx*ncols;
+    x += int64_t(rowx)*ncols;
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    dst += int64_t(rowx)*ncols;

     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

@@ -124,9 +124,9 @@ static __global__ void soft_max_back_f32(
     const int tid = threadIdx.x;
     const int rowx = blockIdx.x;

-    grad += rowx*ncols;
-    dstf += rowx*ncols;
-    dst += rowx*ncols;
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst += int64_t(rowx)*ncols;

     float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients