 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
 }
 }
 
+template <int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+        const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    // default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template <typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
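
For readers unfamiliar with the dispatch idiom used in `launch_soft_max_kernels`, here is a standalone, host-only sketch of the same pattern: a unary fold over `||` tries each compile-time candidate in order and short-circuits on the first one that matches the runtime column count. The `try_size` and `dispatch` names are illustrative only and are not part of the patch.

```cpp
#include <cstdio>

// Hypothetical stand-in for the launch_kernel lambda: succeeds only when the
// runtime column count matches this instantiation's compile-time value.
template <int ncols>
static bool try_size(int runtime_ncols) {
    if (runtime_ncols != ncols) {
        return false; // not this specialization, keep folding
    }
    std::printf("dispatching specialization for ncols = %d\n", ncols);
    return true; // short-circuits the fold; later candidates are not tried
}

template <int... Ns>
static void dispatch(int runtime_ncols) {
    // Unary fold over ||: expands to try_size<32>(...) || try_size<64>(...) || ...
    if ((try_size<Ns>(runtime_ncols) || ...)) {
        return;
    }
    std::printf("falling back to the generic (ncols = 0) path\n"); // default case
}

int main() {
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(256);  // hits the 256 specialization
    dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(1000); // no match -> generic path
}
```

The patch itself wraps each `Ns` in `std::integral_constant` so the generic lambda can recover it as a `constexpr` value; the sketch uses a plain function template instead to keep the example short.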