@@ -10,8 +10,8 @@ static __global__ void cross_entropy_loss_f32(
1010 const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
1111 extern __shared__ float tmp[];
1212
13- logits += blockIdx .x *nclasses;
14- labels += blockIdx .x *nclasses;
13+ logits += int64_t ( blockIdx .x ) *nclasses;
14+ labels += int64_t ( blockIdx .x ) *nclasses;
1515
1616 // Find maximum for softmax:
1717 float max_logit = -INFINITY;
@@ -55,9 +55,9 @@ static __global__ void cross_entropy_loss_back_f32(
5555 float * __restrict__ dst, const int nclasses) {
5656 extern __shared__ float tmp[];
5757
58- logits += blockIdx .x *nclasses;
59- labels += blockIdx .x *nclasses;
60- dst += blockIdx .x *nclasses;
58+ logits += int64_t ( blockIdx .x ) *nclasses;
59+ labels += int64_t ( blockIdx .x ) *nclasses;
60+ dst += int64_t ( blockIdx .x ) *nclasses;
6161
6262 float maxval = -INFINITY;
6363 for (int i = threadIdx .x ; i < nclasses; i += WARP_SIZE) {
@@ -115,10 +115,10 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
115115
116116 const dim3 blocks_dim (WARP_SIZE, 1 , 1 );
117117 const dim3 blocks_num (nrows, 1 , 1 );
118- const int nbytes_shared = ne00*sizeof (float );
118+ const size_t nbytes_shared = ne00*sizeof (float );
119119
120- const int id = ggml_cuda_get_device ();
121- const int smpbo = ggml_cuda_info ().devices [id].smpbo ;
120+ const int id = ggml_cuda_get_device ();
121+ const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
122122
123123 ggml_cuda_pool_alloc<float > dst_tmp (pool, blocks_num.x );
124124
@@ -169,10 +169,10 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
169169
170170 const dim3 blocks_dim (WARP_SIZE, 1 , 1 );
171171 const dim3 blocks_num (nrows, 1 , 1 );
172- const int nbytes_shared = ne00*sizeof (float );
172+ const size_t nbytes_shared = ne00*sizeof (float );
173173
174- const int id = ggml_cuda_get_device ();
175- const int smpbo = ggml_cuda_info ().devices [id].smpbo ;
174+ const int id = ggml_cuda_get_device ();
175+ const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
176176
177177 if (nbytes_shared <= smpbo) {
178178#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
0 commit comments