Commit 34e5142

Review: refactor switch statement, change cross_entropy to use full size
1 parent b9bcb7d commit 34e5142

File tree

3 files changed: +30, -56 lines


ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -187,7 +187,7 @@ static const char * cu_get_error_str(CUresult err) {
     } while (0)
 #else
 #define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
-#endif
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)

 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
```
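The substantive part of this one-liner is the trailing comment: the closing #endif now names the condition it terminates, which makes the guard around CUDA_SET_SHARED_MEMORY_LIMIT easier to match by eye. A minimal sketch of the resulting structure, with the opening #if inferred from that comment and the guarded implementation elided:

```cpp
// Sketch only: the opening condition below is inferred from the comment this commit
// adds to the #endif; the real implementation of the macro is elided.
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
// ... plain CUDA path: CUDA_SET_SHARED_MEMORY_LIMIT raises the kernel's
// ... dynamic shared memory limit ...
#else
// HIP-on-AMD and MUSA builds fall back to a no-op, as shown in the diff above.
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
```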

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -123,7 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

     if (nbytes_shared <= smpbo) {
-        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), nbytes_shared);
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
         cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
     } else {
         cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -169,7 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

     if (nbytes_shared <= smpbo) {
-        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), nbytes_shared);
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
```

ggml/src/ggml-cuda/softmax.cu

Lines changed: 27 additions & 53 deletions
```diff
@@ -182,21 +182,34 @@ static __global__ void soft_max_back_f32(
     }
 }

-template<int... Ns>
-void increase_shared_mem_limits(std::size_t smpbo)
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(float * x, const T * mask, float * dst,
+                                    const soft_max_params & p, cudaStream_t stream)
 {
-    auto apply_limit = [smpbo](auto I) {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, half >), smpbo);
-        CUDA_SET_SHARED_MEMORY_LIMIT(
-            (soft_max_f32<true, ncols, block, float>), smpbo);
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<p.ne01, p.ne02, p.ne03, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
     };

-    //unary fold
-    ( apply_limit(std::integral_constant<int, Ns>{}), ... );
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    //default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<p.ne01, p.ne02, p.ne03, stream>>>(x, mask, dst, p);
 }


@@ -217,47 +230,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons


     if (nbytes_shared <= smpbo) {
-
-        increase_shared_mem_limits<0, 32, 64, 128, 256, 512, 1024, 2048, 4096>(smpbo);
-
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
```
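The second hunk replaces the nine-case switch over ncols_x with a single call to launch_soft_max_kernels, which walks the compile-time list <32, ..., 4096> using a unary fold over ||: the first launch_kernel invocation whose ncols equals p.ncols launches the matching specialization and short-circuits the rest, and the generic soft_max_f32<true, 0, 0> kernel remains the fallback. A standalone sketch of that dispatch pattern (no CUDA; dispatch, try_one, and launch_specialized are illustrative names, and the 1024-thread cap mirrors the diff):

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in for launching soft_max_f32<true, ncols, block>; it only reports which
// specialization would have been selected.
template<int ncols, int block>
static void launch_specialized(int runtime_ncols) {
    std::printf("ncols=%d -> soft_max_f32<true, %d, %d>\n", runtime_ncols, ncols, block);
}

template<int... Ns>
static void dispatch(int runtime_ncols) {
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        constexpr int block = (ncols > 1024 ? 1024 : ncols); // same thread cap as in the diff
        if (runtime_ncols == ncols) {
            launch_specialized<ncols, block>(runtime_ncols);
            return true; // handled: stop trying further sizes
        }
        return false;
    };

    // Unary fold over ||: tries each Ns in order and short-circuits on the first match.
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }

    // Default case, mirroring the generic soft_max_f32<true, 0, 0> launch.
    std::printf("ncols=%d -> soft_max_f32<true, 0, 0> (generic)\n", runtime_ncols);
}

int main() {
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(2048); // selects <true, 2048, 1024>
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(100);  // falls through to the generic kernel
    return 0;
}
```

In the commit, the body of the matching branch additionally raises the kernel's shared memory limit via CUDA_SET_SHARED_MEMORY_LIMIT before launching, which is what the old increase_shared_mem_limits helper previously did for every size up front.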
