 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>

 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -188,6 +189,37 @@ static __global__ void soft_max_back_f32(
     }
 }

+template <int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+        const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template <typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -200,46 +232,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");


-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
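
The new `launch_soft_max_kernels` replaces the hand-written switch with a C++17 unary fold over `||`: each value in the `Ns...` pack is wrapped in a `std::integral_constant`, `launch_kernel` instantiates the kernel for that compile-time `ncols`, and the fold short-circuits on the first match, falling back to the generic `<true, 0, 0>` instantiation otherwise. A minimal standalone sketch of the same dispatch pattern (the `launch_fixed`/`launch_generic` functions are hypothetical stand-ins for the kernel launches, not part of ggml):

```cpp
#include <cstdio>
#include <utility>

// Hypothetical stand-in for a kernel launch specialized on a compile-time column count.
template <int ncols>
static void launch_fixed() {
    std::printf("compile-time path, ncols = %d\n", ncols);
}

// Hypothetical stand-in for the generic (runtime-sized) fallback launch.
static void launch_generic(int ncols) {
    std::printf("generic fallback, ncols = %d\n", ncols);
}

// Try each N in the pack; the unary fold over || stops at the first lambda call
// that returns true, mirroring launch_soft_max_kernels above.
template <int... Ns>
static void dispatch(int ncols) {
    auto try_one = [&](auto I) -> bool {
        constexpr int n = decltype(I)::value;
        if (ncols == n) {
            launch_fixed<n>();
            return true;
        }
        return false;
    };

    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    launch_generic(ncols); // no compile-time case matched
}

int main() {
    dispatch<32, 64, 128, 256>(128); // hits the ncols = 128 instantiation
    dispatch<32, 64, 128, 256>(96);  // falls through to the generic path
    return 0;
}
```

Compared to the switch, supporting another column size only means extending the template argument list at the call site; the per-case launch code is written once in the lambda.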
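The other change is the move from `smpb` to `smpbo` (shared memory per block, opt-in) together with the `CUDA_SET_SHARED_MEMORY_LIMIT` call before each launch, which addresses the removed FIXME: since Volta, a kernel may use more than the default 48 KiB of dynamic shared memory, but only after explicitly opting in per kernel. Assuming the macro boils down to `cudaFuncSetAttribute` (the standard runtime call for this; the kernel and sizes below are illustrative, not ggml's), the opt-in looks roughly like:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative kernel using dynamically sized shared memory.
__global__ void demo_kernel(float * dst) {
    extern __shared__ float buf[];
    buf[threadIdx.x] = (float) threadIdx.x;
    dst[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, 0);

    // sharedMemPerBlockOptin is the per-device counterpart of ggml's smpbo.
    const size_t smpbo = prop.sharedMemPerBlockOptin;

    // Without this call, launches requesting more than the default 48 KiB of
    // dynamic shared memory fail on most devices.
    cudaError_t err = cudaFuncSetAttribute(demo_kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smpbo);
    if (err != cudaSuccess) {
        std::printf("opt-in failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    float * dst = nullptr;
    cudaMalloc(&dst, 32*sizeof(float));
    demo_kernel<<<1, 32, smpbo>>>(dst); // launch with the opted-in amount
    cudaDeviceSynchronize();
    cudaFree(dst);
    return 0;
}
```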