|
1 | 1 | #include "softmax.hpp" |
2 | 2 |
|
3 | | -template <typename T> static inline float t2f32(T val) { |
4 | | - return static_cast<float>(val); |
5 | | -} |
6 | | - |
7 | | -template <> inline float t2f32<sycl::half>(sycl::half val) { |
8 | | - return static_cast<float>(val); |
9 | | -} |
10 | | - |
11 | 3 | template <bool vals_smem, int ncols_template, int block_size_template, typename T> |
12 | 4 | static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, |
13 | 5 | const int nrows_y, const float scale, const float max_bias, const float m0, |
@@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int |
51 | 43 | const int ix = rowx*ncols + col; |
52 | 44 | const int iy = rowy*ncols + col; |
53 | 45 |
|
54 | | - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); |
| 46 | + const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f); |
55 | 47 |
|
56 | 48 | vals[col] = val; |
57 | 49 | max_val = sycl::max(max_val, val); |
@@ -142,7 +134,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, |
142 | 134 | cgh.parallel_for( |
143 | 135 | sycl::nd_range<3>(block_nums * block_dims, block_dims), |
144 | 136 | [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { |
145 | | - soft_max_f32<vals_smem, ncols_template, block_size_template, T>(x, mask, dst, ncols_par, |
| 137 | + soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par, |
146 | 138 | nrows_y, scale, max_bias, m0, |
147 | 139 | m1, n_head_log2, item_ct1, |
148 | 140 | get_pointer(local_buf_acc)); |
@@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask, |
174 | 166 | const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>(); |
175 | 167 | if (n_local_scratch*sizeof(float) < local_mem_size) { |
176 | 168 | if (ncols_x > max_block_size) { |
177 | | - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 169 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
178 | 170 | max_bias, m0, m1, n_head_log2, block_nums, |
179 | 171 | block_dims, n_local_scratch, stream); |
180 | 172 | return; |
181 | 173 | } |
182 | 174 | switch (ncols_x) { |
183 | 175 | case 32: |
184 | | - soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 176 | + soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale, |
185 | 177 | max_bias, m0, m1, n_head_log2, block_nums, |
186 | 178 | block_dims, n_local_scratch, stream); |
187 | 179 | break; |
188 | 180 | case 64: |
189 | | - soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 181 | + soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale, |
190 | 182 | max_bias, m0, m1, n_head_log2, block_nums, |
191 | 183 | block_dims, n_local_scratch, stream); |
192 | 184 | break; |
193 | 185 | case 128: |
194 | | - soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 186 | + soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale, |
195 | 187 | max_bias, m0, m1, n_head_log2, block_nums, |
196 | 188 | block_dims, n_local_scratch, stream); |
197 | 189 | break; |
198 | 190 | case 256: |
199 | | - soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 191 | + soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale, |
200 | 192 | max_bias, m0, m1, n_head_log2, block_nums, |
201 | 193 | block_dims, n_local_scratch, stream); |
202 | 194 | break; |
203 | 195 | case 512: |
204 | | - soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 196 | + soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale, |
205 | 197 | max_bias, m0, m1, n_head_log2, block_nums, |
206 | 198 | block_dims, n_local_scratch, stream); |
207 | 199 | break; |
208 | 200 | case 1024: |
209 | | - soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 201 | + soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
210 | 202 | max_bias, m0, m1, n_head_log2, block_nums, |
211 | 203 | block_dims, n_local_scratch, stream); |
212 | 204 | break; |
213 | 205 | case 2048: |
214 | | - soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 206 | + soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
215 | 207 | max_bias, m0, m1, n_head_log2, block_nums, |
216 | 208 | block_dims, n_local_scratch, stream); |
217 | 209 | break; |
218 | 210 | case 4096: |
219 | | - soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 211 | + soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
220 | 212 | max_bias, m0, m1, n_head_log2, block_nums, |
221 | 213 | block_dims, n_local_scratch, stream); |
222 | 214 | break; |
223 | 215 | default: |
224 | | - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 216 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
225 | 217 | max_bias, m0, m1, n_head_log2, block_nums, |
226 | 218 | block_dims, n_local_scratch, stream); |
227 | 219 | break; |
228 | 220 | } |
229 | 221 | } else { |
230 | | - soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 222 | + soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
231 | 223 | max_bias, m0, m1, n_head_log2, block_nums, |
232 | 224 | block_dims, WARP_SIZE, stream); |
233 | 225 | } |
|
0 commit comments