@@ -151,21 +151,37 @@ static __global__ void soft_max_back_f32(
151151 }
152152}
153153
154- template <int ... Ns>
155- void increase_shared_mem_limits (std::size_t smpbo)
154+ template <int ... Ns, typename T>
155+ static void launch_soft_max_kernels (int ncols_x, const float * x, const T * mask, float * dst,
156+ int ncols_param, int nrows_y, float scale, float max_bias,
157+ float m0, float m1, uint32_t n_head_log2, dim3 block_nums,
158+ dim3 block_dims, size_t nbytes_shared, cudaStream_t stream)
156159{
157- auto apply_limit = [smpbo](auto I) {
158- constexpr int ncols = decltype (I)::value;
159- constexpr int block = (ncols > 1024 ? 1024 : ncols);
160-
161- CUDA_SET_SHARED_MEMORY_LIMIT (
162- (soft_max_f32<true , ncols, block, half >), smpbo);
163- CUDA_SET_SHARED_MEMORY_LIMIT (
164- (soft_max_f32<true , ncols, block, float >), smpbo);
160+ const int id = ggml_cuda_get_device ();
161+ const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
162+
163+ auto launch_kernel = [=](auto I) -> bool {
164+ constexpr int ncols = decltype (I)::value;
165+ constexpr int block = (ncols > 1024 ? 1024 : ncols);
166+
167+ if (ncols_x == ncols) {
168+ CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , ncols, block, T>), smpbo);
169+ soft_max_f32<true , ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
170+ (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
171+ return true ;
172+ }
173+ return false ;
165174 };
166175
167- // unary fold
168- ( apply_limit (std::integral_constant<int , Ns>{}), ... );
176+ // unary fold over launch_kernel
177+ if ((launch_kernel (std::integral_constant<int , Ns>{}) || ...)) {
178+ return ;
179+ }
180+
181+ // default case
182+ CUDA_SET_SHARED_MEMORY_LIMIT ((soft_max_f32<true , 0 , 0 , T>), smpbo);
183+ soft_max_f32<true , 0 , 0 ><<<block_nums, block_dims, nbytes_shared, stream>>>
184+ (x, mask, dst, ncols_param, nrows_y, scale, max_bias, m0, m1, n_head_log2);
169185}
170186
171187
@@ -189,47 +205,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
189205
190206
191207 if (nbytes_shared <= smpbo) {
192-
193- increase_shared_mem_limits<0 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 >(smpbo);
194-
195- switch (ncols_x) {
196- case 32 :
197- soft_max_f32<true , 32 , 32 ><<<block_nums, block_dims, nbytes_shared, stream>>>
198- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
199- break ;
200- case 64 :
201- soft_max_f32<true , 64 , 64 ><<<block_nums, block_dims, nbytes_shared, stream>>>
202- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
203- break ;
204- case 128 :
205- soft_max_f32<true , 128 , 128 ><<<block_nums, block_dims, nbytes_shared, stream>>>
206- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
207- break ;
208- case 256 :
209- soft_max_f32<true , 256 , 256 ><<<block_nums, block_dims, nbytes_shared, stream>>>
210- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
211- break ;
212- case 512 :
213- soft_max_f32<true , 512 , 512 ><<<block_nums, block_dims, nbytes_shared, stream>>>
214- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
215- break ;
216- case 1024 :
217- soft_max_f32<true , 1024 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
218- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
219- break ;
220- case 2048 :
221- soft_max_f32<true , 2048 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
222- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
223- break ;
224- case 4096 :
225- soft_max_f32<true , 4096 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
226- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
227- break ;
228- default :
229- soft_max_f32<true , 0 , 0 ><<<block_nums, block_dims, nbytes_shared, stream>>>
230- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
231- break ;
232- }
208+ launch_soft_max_kernels<32 , 64 , 128 , 256 , 512 , 1024 , 2048 , 4096 >(
209+ ncols_x, x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, nbytes_shared, stream);
233210 } else {
234211 const size_t nbytes_shared_low = WARP_SIZE*sizeof (float );
235212 soft_max_f32<false , 0 , 0 ><<<block_nums, block_dims, nbytes_shared_low, stream>>> (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
0 commit comments