@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
123123 ggml_cuda_pool_alloc<float > dst_tmp (pool, blocks_num.x );
124124
125125 if (nbytes_shared <= smpbo) {
126- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
127- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false };
128- if (!shared_memory_limit_raised[id]) {
129- CUDA_CHECK (cudaFuncSetAttribute (cross_entropy_loss_f32<true >, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
130- shared_memory_limit_raised[id] = true ;
131- }
132- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
126+ CUDA_SET_SHARED_MEMORY_LIMIT ((cross_entropy_loss_f32<true >), nbytes_shared);
133127 cross_entropy_loss_f32<true ><<<blocks_num, blocks_dim, nbytes_shared, stream>>> (src0_d, src1_d, dst_tmp.ptr , ne00, nrows);
134128 } else {
135129 cross_entropy_loss_f32<false ><<<blocks_num, blocks_dim, 0 , stream>>> (src0_d, src1_d, dst_tmp.ptr , ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
175169 const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
176170
177171 if (nbytes_shared <= smpbo) {
178- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
179- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false };
180- if (!shared_memory_limit_raised[id]) {
181- CUDA_CHECK (cudaFuncSetAttribute (cross_entropy_loss_back_f32<true >, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
182- shared_memory_limit_raised[id] = true ;
183- }
184- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
172+ CUDA_SET_SHARED_MEMORY_LIMIT ((cross_entropy_loss_back_f32<true >), nbytes_shared);
185173 cross_entropy_loss_back_f32<true ><<<blocks_num, blocks_dim, nbytes_shared, stream>>> (grad_d, src0f_d, src1f_d, dst_d, ne00);
186174 } else {
187175 cross_entropy_loss_back_f32<false ><<<blocks_num, blocks_dim, 0 , stream>>> (grad_d, src0f_d, src1f_d, dst_d, ne00);
0 commit comments