11#include " softmax.hpp"
22
3- template <bool vals_smem, int ncols_template, int block_size_template, typename T>
4- static void soft_max_f32 (const float * x, const T * mask, float * dst, const int ncols_par ,
3+ template <typename T>
4+ static void soft_max_f32 (const float * x, const T * mask, float * dst, int ncols ,
55 const int nrows_y, const float scale, const float max_bias, const float m0,
6- const float m1, uint32_t n_head_log2, const sycl::nd_item<3 > &item_ct1, float *buf) {
7- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
6+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3 > &item_ct1, float *buf,
7+ const bool vals_smem, const bool check_columns_count) {
88
99 const int tid = item_ct1.get_local_id (2 );
1010 const int rowx = item_ct1.get_group (2 );
1111 const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
1212
13- const int block_size = block_size_template == 0 ? item_ct1.get_local_range (2 ) : block_size_template ;
13+ const int block_size = check_columns_count ? item_ct1.get_local_range (2 ) : std::min (ncols, 1024 ) ;
1414
1515 const int warp_id = item_ct1.get_local_id (2 ) / WARP_SIZE;
1616 const int lane_id = item_ct1.get_local_id (2 ) % WARP_SIZE;
@@ -35,7 +35,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
3535 for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
3636 const int col = col0 + tid;
3737
38- if (ncols_template == 0 && col >= ncols) {
38+ if (check_columns_count && col >= ncols) {
3939 break ;
4040 }
4141
@@ -74,7 +74,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
7474#pragma unroll
7575 for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
7676 const int col = col0 + tid;
77- if (ncols_template == 0 && col >= ncols) {
77+ if (check_columns_count && col >= ncols) {
7878 break ;
7979 }
8080
@@ -113,7 +113,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
113113 for (int col0 = 0 ; col0 < ncols; col0 += block_size) {
114114 const int col = col0 + tid;
115115
116- if (ncols_template == 0 && col >= ncols) {
116+ if (check_columns_count && col >= ncols) {
117117 return ;
118118 }
119119
@@ -122,25 +122,6 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
122122 }
123123}
124124
125- template <bool vals_smem, int ncols_template, int block_size_template, typename T>
126- static void soft_max_f32_submitter (const float * x, const T * mask, float * dst, const int ncols_par,
127- const int nrows_y, const float scale, const float max_bias, const float m0,
128- const float m1, uint32_t n_head_log2, sycl::range<3 > block_nums, sycl::range<3 > block_dims,
129- const size_t n_local_scratch, queue_ptr stream) {
130- stream->submit ([&](sycl::handler &cgh) {
131- sycl::local_accessor<float , 1 > local_buf_acc (n_local_scratch, cgh);
132-
133- cgh.parallel_for (
134- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
135- [=](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
136- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
137- nrows_y, scale, max_bias, m0,
138- m1, n_head_log2, item_ct1,
139- get_pointer (local_buf_acc));
140- });
141- });
142- }
143-
144125template <typename T>
145126static void soft_max_f32_sycl (const float * x, const T * mask,
146127 float * dst, const int ncols_x, const int nrows_x,
@@ -163,64 +144,28 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
163144 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
164145
165146 const size_t local_mem_size = stream->get_device ().get_info <sycl::info::device::local_mem_size>();
166- if (n_local_scratch*sizeof (float ) < local_mem_size) {
167- if (ncols_x > max_block_size) {
168- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
169- max_bias, m0, m1, n_head_log2, block_nums,
170- block_dims, n_local_scratch, stream);
171- return ;
172- }
173- switch (ncols_x) {
174- case 32 :
175- soft_max_f32_submitter<true , 32 , 32 >(x, mask, dst, ncols_x, nrows_y, scale,
176- max_bias, m0, m1, n_head_log2, block_nums,
177- block_dims, n_local_scratch, stream);
178- break ;
179- case 64 :
180- soft_max_f32_submitter<true , 64 , 64 >(x, mask, dst, ncols_x, nrows_y, scale,
181- max_bias, m0, m1, n_head_log2, block_nums,
182- block_dims, n_local_scratch, stream);
183- break ;
184- case 128 :
185- soft_max_f32_submitter<true , 128 , 128 >(x, mask, dst, ncols_x, nrows_y, scale,
186- max_bias, m0, m1, n_head_log2, block_nums,
187- block_dims, n_local_scratch, stream);
188- break ;
189- case 256 :
190- soft_max_f32_submitter<true , 256 , 256 >(x, mask, dst, ncols_x, nrows_y, scale,
191- max_bias, m0, m1, n_head_log2, block_nums,
192- block_dims, n_local_scratch, stream);
193- break ;
194- case 512 :
195- soft_max_f32_submitter<true , 512 , 512 >(x, mask, dst, ncols_x, nrows_y, scale,
196- max_bias, m0, m1, n_head_log2, block_nums,
197- block_dims, n_local_scratch, stream);
198- break ;
199- case 1024 :
200- soft_max_f32_submitter<true , 1024 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
201- max_bias, m0, m1, n_head_log2, block_nums,
202- block_dims, n_local_scratch, stream);
203- break ;
204- case 2048 :
205- soft_max_f32_submitter<true , 2048 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
206- max_bias, m0, m1, n_head_log2, block_nums,
207- block_dims, n_local_scratch, stream);
208- break ;
209- case 4096 :
210- soft_max_f32_submitter<true , 4096 , 1024 >(x, mask, dst, ncols_x, nrows_y, scale,
211- max_bias, m0, m1, n_head_log2, block_nums,
212- block_dims, n_local_scratch, stream);
213- break ;
214- default :
215- soft_max_f32_submitter<true , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
216- max_bias, m0, m1, n_head_log2, block_nums,
217- block_dims, n_local_scratch, stream);
218- break ;
219- }
147+
148+ auto soft_max_f32_submit = [=](size_t scratch_size, bool vals_smem, bool check_columns_count) {
149+ stream->submit ([=](sycl::handler &cgh) {
150+ sycl::local_accessor<float , 1 > local_buf_acc (scratch_size, cgh);
151+ cgh.parallel_for (
152+ sycl::nd_range<3 >(block_nums * block_dims, block_dims),
153+ [=](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
154+ soft_max_f32 (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2,
155+ item_ct1, get_pointer (local_buf_acc), vals_smem, check_columns_count);
156+ });
157+ });
158+ };
159+
160+ if (n_local_scratch*sizeof (float ) >= local_mem_size) {
161+ soft_max_f32_submit (WARP_SIZE, false , true );
162+ } else if (ncols_x > max_block_size) {
163+ soft_max_f32_submit (n_local_scratch, true , true );
164+ } else if (ncols_x == 32 || ncols_x == 64 || ncols_x == 128 || ncols_x == 256
165+ || ncols_x == 512 || ncols_x == 1024 || ncols_x == 2048 || ncols_x == 4096 ) {
166+ soft_max_f32_submit (n_local_scratch, true , false );
220167 } else {
221- soft_max_f32_submitter<false , 0 , 0 >(x, mask, dst, ncols_x, nrows_y, scale,
222- max_bias, m0, m1, n_head_log2, block_nums,
223- block_dims, WARP_SIZE, stream);
168+ soft_max_f32_submit (n_local_scratch, true , true );
224169 }
225170}
226171
0 commit comments