|
1 | 1 | #include "softmax.hpp" |
2 | 2 |
|
3 | | -template <typename T> static inline float t2f32(T val) { |
4 | | - return static_cast<float>(val); |
5 | | -} |
6 | | - |
7 | | -template <> inline float t2f32<sycl::half>(sycl::half val) { |
8 | | - return static_cast<float>(val); |
9 | | -} |
10 | | - |
11 | 3 | template <bool vals_smem, int ncols_template, int block_size_template, typename T> |
12 | 4 | static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, |
13 | 5 | const int nrows_y, const float scale, const float max_bias, const float m0, |
@@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int |
51 | 43 | const int ix = rowx*ncols + col; |
52 | 44 | const int iy = rowy*ncols + col; |
53 | 45 |
|
54 | | - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); |
| 46 | + const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f); |
55 | 47 |
|
56 | 48 | vals[col] = val; |
57 | 49 | max_val = sycl::max(max_val, val); |
@@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask, |
174 | 166 | const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>(); |
175 | 167 | if (n_local_scratch*sizeof(float) < local_mem_size) { |
176 | 168 | if (ncols_x > max_block_size) { |
177 | | - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 169 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
178 | 170 | max_bias, m0, m1, n_head_log2, block_nums, |
179 | 171 | block_dims, n_local_scratch, stream); |
180 | 172 | return; |
181 | 173 | } |
182 | 174 | switch (ncols_x) { |
183 | 175 | case 32: |
184 | | - soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 176 | + soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale, |
185 | 177 | max_bias, m0, m1, n_head_log2, block_nums, |
186 | 178 | block_dims, n_local_scratch, stream); |
187 | 179 | break; |
188 | 180 | case 64: |
189 | | - soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 181 | + soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale, |
190 | 182 | max_bias, m0, m1, n_head_log2, block_nums, |
191 | 183 | block_dims, n_local_scratch, stream); |
192 | 184 | break; |
193 | 185 | case 128: |
194 | | - soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 186 | + soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale, |
195 | 187 | max_bias, m0, m1, n_head_log2, block_nums, |
196 | 188 | block_dims, n_local_scratch, stream); |
197 | 189 | break; |
198 | 190 | case 256: |
199 | | - soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 191 | + soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale, |
200 | 192 | max_bias, m0, m1, n_head_log2, block_nums, |
201 | 193 | block_dims, n_local_scratch, stream); |
202 | 194 | break; |
203 | 195 | case 512: |
204 | | - soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 196 | + soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale, |
205 | 197 | max_bias, m0, m1, n_head_log2, block_nums, |
206 | 198 | block_dims, n_local_scratch, stream); |
207 | 199 | break; |
208 | 200 | case 1024: |
209 | | - soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 201 | + soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
210 | 202 | max_bias, m0, m1, n_head_log2, block_nums, |
211 | 203 | block_dims, n_local_scratch, stream); |
212 | 204 | break; |
213 | 205 | case 2048: |
214 | | - soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 206 | + soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
215 | 207 | max_bias, m0, m1, n_head_log2, block_nums, |
216 | 208 | block_dims, n_local_scratch, stream); |
217 | 209 | break; |
218 | 210 | case 4096: |
219 | | - soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 211 | + soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
220 | 212 | max_bias, m0, m1, n_head_log2, block_nums, |
221 | 213 | block_dims, n_local_scratch, stream); |
222 | 214 | break; |
223 | 215 | default: |
224 | | - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 216 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
225 | 217 | max_bias, m0, m1, n_head_log2, block_nums, |
226 | 218 | block_dims, n_local_scratch, stream); |
227 | 219 | break; |
228 | 220 | } |
229 | 221 | } else { |
230 | | - soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 222 | + soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
231 | 223 | max_bias, m0, m1, n_head_log2, block_nums, |
232 | 224 | block_dims, WARP_SIZE, stream); |
233 | 225 | } |
|
0 commit comments