@@ -182,21 +182,34 @@ static __global__ void soft_max_back_f32(
182182 }
183183}
184184
185- template <int ... Ns>
186- void increase_shared_mem_limits (std::size_t smpbo)
185+ template <int ... Ns, typename T>
186+ static void launch_soft_max_kernels (float * x, const T * mask, float * dst,
187+ const soft_max_params & p, cudaStream_t stream)
187188{
188- auto apply_limit = [smpbo](auto I) {
189- constexpr int ncols = decltype (I)::value;
190- constexpr int block = (ncols > 1024 ? 1024 : ncols);
191-
192- CUDA_SET_SHARED_MEMORY_LIMIT (
193- (soft_max_f32<true , ncols, block, half >), smpbo);
194- CUDA_SET_SHARED_MEMORY_LIMIT (
195- (soft_max_f32<true , ncols, block, float >), smpbo);
189+ const int id = ggml_cuda_get_device ();
190+ const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
191+
192+ auto launch_kernel = [=](auto I) -> bool {
193+ constexpr int ncols = decltype (I)::value;
194+ constexpr int block = (ncols > 1024 ? 1024 : ncols);
195+
196+ if (p.ncols == ncols) {
197+ CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , ncols, block, T>), smpbo);
198+ soft_max_f32<true , ncols, block><<<p.ne01, p.ne02, p.ne03, stream>>>
199+ (x, mask, dst, p);
200+ return true ;
201+ }
202+ return false ;
196203 };
197204
198- // unary fold
199- ( apply_limit (std::integral_constant<int , Ns>{}), ... );
205+ // unary fold over launch_kernel
206+ if ((launch_kernel (std::integral_constant<int , Ns>{}) || ...)) {
207+ return ;
208+ }
209+
210+ // default case
211+ CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , 0 , 0 , T>), smpbo);
212+ soft_max_f32<true , 0 , 0 ><<<p.ne01, p.ne02, p.ne03, stream>>> (x, mask, dst, p);
200213}
201214
202215
@@ -217,47 +230,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
217230
218231
219232 if (nbytes_shared <= smpbo) {
220-
221- increase_shared_mem_limits<0 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 >(smpbo);
222-
223- switch (ncols_x) {
224- case 32 :
225- soft_max_f32<true , 32 , 32 ><<<block_nums, block_dims, nbytes_shared, stream>>>
226- (x, mask, dst, params);
227- break ;
228- case 64 :
229- soft_max_f32<true , 64 , 64 ><<<block_nums, block_dims, nbytes_shared, stream>>>
230- (x, mask, dst, params);
231- break ;
232- case 128 :
233- soft_max_f32<true , 128 , 128 ><<<block_nums, block_dims, nbytes_shared, stream>>>
234- (x, mask, dst, params);
235- break ;
236- case 256 :
237- soft_max_f32<true , 256 , 256 ><<<block_nums, block_dims, nbytes_shared, stream>>>
238- (x, mask, dst, params);
239- break ;
240- case 512 :
241- soft_max_f32<true , 512 , 512 ><<<block_nums, block_dims, nbytes_shared, stream>>>
242- (x, mask, dst, params);
243- break ;
244- case 1024 :
245- soft_max_f32<true , 1024 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
246- (x, mask, dst, params);
247- break ;
248- case 2048 :
249- soft_max_f32<true , 2048 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
250- (x, mask, dst, params);
251- break ;
252- case 4096 :
253- soft_max_f32<true , 4096 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
254- (x, mask, dst, params);
255- break ;
256- default :
257- soft_max_f32<true , 0 , 0 ><<<block_nums, block_dims, nbytes_shared, stream>>>
258- (x, mask, dst, params);
259- break ;
260- }
233+
234+ launch_soft_max_kernels<32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 >(x, mask, dst, params, stream);
261235 } else {
262236 const size_t nbytes_shared_low = WARP_SIZE*sizeof (float );
263237 soft_max_f32<false , 0 , 0 ><<<block_nums, block_dims, nbytes_shared_low, stream>>> (x, mask, dst, params);
0 commit comments