@@ -13,6 +13,28 @@ __device__ float __forceinline__ t2f32<half>(half val) {
1313 return __half2float (val);
1414}
1515
// Host-side launch parameters for soft_max_f32, passed to the kernel by value.
// NOTE: member order matters — the host code initializes this struct with
// order-dependent designated initializers.
struct soft_max_params {

    int64_t nheads;       // number of heads (src0->ne[2])
    uint32_t n_head_log2; // 2^floor(log2(nheads)); selects ALiBi slope base (see get_alibi_slope)
    int64_t ncols;        // row length of x (== ne00); elements soft-maxed per row
    int64_t nrows_x;      // total rows of x; one CUDA block is launched per row
    int64_t nrows_y;      // src0->ne[1] (kept for the host-side setup)
    int64_t ne00;         // src0 dim 0
    int64_t ne01;         // src0 dim 1
    int64_t ne02;         // src0 dim 2
    int64_t nb11;         // mask (src1) stride for dim 1, in BYTES (kernel divides by sizeof(T))
    int64_t nb12;         // mask stride for dim 2, in bytes
    int64_t nb13;         // mask stride for dim 3, in bytes

    int64_t ne12;         // mask dim 2; broadcast via i12 = i02 % ne12
    int64_t ne13;         // mask dim 3; broadcast via i13 = i03 % ne13
    float scale;          // multiplier applied to x before the mask is added
    float max_bias;       // ALiBi max bias (m0/m1 are derived from it on the host)
    float m0;             // precomputed 2^(-max_bias / n_head_log2)
    float m1;             // precomputed 2^(-(max_bias/2) / n_head_log2)
};
37+
1638// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
1739// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
1840#ifdef __clang__
@@ -21,24 +43,30 @@ __device__ float __forceinline__ t2f32<half>(half val) {
2143#endif // __clang__
2244template <bool use_shared, int ncols_template, int block_size_template, typename T>
2345static __global__ void soft_max_f32 (
24- const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
25- const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
26- const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
46+ const float * x, const T * mask, float * dst, const soft_max_params p) {
47+ const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
2748
2849 const int tid = threadIdx .x ;
2950 const int rowx = blockIdx .x ;
30- const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
51+
52+ const int64_t i03 = rowx / (p.ne01 * p.ne02 );
53+ const int64_t i02 = (rowx % (p.ne01 * p.ne02 )) / p.ne01 ;
54+ const int64_t i01 = rowx % p.ne01 ;
55+
56+ const int64_t i11 = i01;
57+ const int64_t i12 = i02 % p.ne12 ;
58+ const int64_t i13 = i03 % p.ne13 ;
3159
3260 x += int64_t (rowx)*ncols;
33- mask += int64_t (rowy)*ncols * (mask != nullptr );
61+ mask += ( int64_t (i11)*p. nb11 + int64_t (i12)*p. nb12 + int64_t (i13)*p. nb13 ) / sizeof (T) * (mask != nullptr );
3462 dst += int64_t (rowx)*ncols;
3563
3664 const int block_size = block_size_template == 0 ? blockDim .x : block_size_template;
3765
3866 const int warp_id = threadIdx .x / WARP_SIZE;
3967 const int lane_id = threadIdx .x % WARP_SIZE;
4068
41- const float slope = get_alibi_slope (max_bias, rowx/nrows_y, n_head_log2, m0, m1);
69+ const float slope = get_alibi_slope (p. max_bias , i02, p. n_head_log2 , p. m0 , p. m1 );
4270
4371 extern __shared__ float data_soft_max_f32[];
4472 float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +83,7 @@ static __global__ void soft_max_f32(
5583 break ;
5684 }
5785
58- const float val = x[col]*scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
86+ const float val = x[col]*p. scale + (mask ? slope*t2f32 (mask[col]) : 0 .0f );
5987
6088 vals[col] = val;
6189 max_val = max (max_val, val);
@@ -151,63 +179,60 @@ static __global__ void soft_max_back_f32(
151179}
152180
153181template <typename T>
154- static void soft_max_f32_cuda (const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias , cudaStream_t stream) {
182+ static void soft_max_f32_cuda (const float * x, const T * mask, float * dst, soft_max_params params , cudaStream_t stream) {
155183 int nth = WARP_SIZE;
184+ const int64_t ncols_x = params.ncols ;
185+
156186 while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2 ;
157187 const dim3 block_dims (nth, 1 , 1 );
158- const dim3 block_nums (nrows_x, 1 , 1 );
188+ const dim3 block_nums (params. nrows_x , 1 , 1 );
159189 const size_t nbytes_shared = (GGML_PAD (ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof (float );
160190 static_assert (CUDA_SOFT_MAX_BLOCK_SIZE == 1024 , " These values need to be adjusted." );
161191
162- const uint32_t n_head = nrows_x/nrows_y;
163- const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
164-
165- const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
166- const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
167192
168193 // FIXME: this limit could be raised by ~2-4x on Ampere or newer
169194 if (nbytes_shared < ggml_cuda_info ().devices [ggml_cuda_get_device ()].smpb ) {
170195 switch (ncols_x) {
171196 case 32 :
172197 soft_max_f32<true , 32 , 32 ><<<block_nums, block_dims, nbytes_shared, stream>>>
173- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
198+ (x, mask, dst, params );
174199 break ;
175200 case 64 :
176201 soft_max_f32<true , 64 , 64 ><<<block_nums, block_dims, nbytes_shared, stream>>>
177- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
202+ (x, mask, dst, params );
178203 break ;
179204 case 128 :
180205 soft_max_f32<true , 128 , 128 ><<<block_nums, block_dims, nbytes_shared, stream>>>
181- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
206+ (x, mask, dst, params );
182207 break ;
183208 case 256 :
184209 soft_max_f32<true , 256 , 256 ><<<block_nums, block_dims, nbytes_shared, stream>>>
185- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
210+ (x, mask, dst, params );
186211 break ;
187212 case 512 :
188213 soft_max_f32<true , 512 , 512 ><<<block_nums, block_dims, nbytes_shared, stream>>>
189- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
214+ (x, mask, dst, params );
190215 break ;
191216 case 1024 :
192217 soft_max_f32<true , 1024 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
193- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
218+ (x, mask, dst, params );
194219 break ;
195220 case 2048 :
196221 soft_max_f32<true , 2048 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
197- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
222+ (x, mask, dst, params );
198223 break ;
199224 case 4096 :
200225 soft_max_f32<true , 4096 , 1024 ><<<block_nums, block_dims, nbytes_shared, stream>>>
201- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
226+ (x, mask, dst, params );
202227 break ;
203228 default :
204229 soft_max_f32<true , 0 , 0 ><<<block_nums, block_dims, nbytes_shared, stream>>>
205- (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
230+ (x, mask, dst, params );
206231 break ;
207232 }
208233 } else {
209234 const size_t nbytes_shared_low = WARP_SIZE*sizeof (float );
210- soft_max_f32<false , 0 , 0 ><<<block_nums, block_dims, nbytes_shared_low, stream>>> (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2 );
235+ soft_max_f32<false , 0 , 0 ><<<block_nums, block_dims, nbytes_shared_low, stream>>> (x, mask, dst, params );
211236 }
212237}
213238
@@ -235,10 +260,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
235260
236261 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
237262
238- const int64_t ne00 = src0->ne [0 ];
239263 const int64_t nrows_x = ggml_nrows (src0);
240264 const int64_t nrows_y = src0->ne [1 ];
241265
266+ const int64_t ne00 = src0->ne [0 ];
267+
242268 float scale = 1 .0f ;
243269 float max_bias = 0 .0f ;
244270
@@ -247,10 +273,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
247273
248274 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
249275
276+ const int64_t nb11 = src1 ? src1->nb [1 ] : 1 ;
277+ const int64_t nb12 = src1 ? src1->nb [2 ] : 1 ;
278+ const int64_t nb13 = src1 ? src1->nb [3 ] : 1 ;
279+
280+ const int64_t ne12 = src1 ? src1->ne [2 ] : 1 ;
281+ const int64_t ne13 = src1 ? src1->ne [3 ] : 1 ;
282+
283+ const uint32_t n_head = src0->ne [2 ];
284+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
285+
286+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
287+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
288+
289+
290+ soft_max_params params = {
291+ .nheads = src0->ne [2 ],
292+ .n_head_log2 = n_head_log2,
293+ .ncols = ne00,
294+ .nrows_x = nrows_x,
295+ .nrows_y = nrows_y,
296+ .ne00 = src0->ne [0 ],
297+ .ne01 = src0->ne [1 ],
298+ .ne02 = src0->ne [2 ],
299+ .nb11 = nb11,
300+ .nb12 = nb12,
301+ .nb13 = nb13,
302+ .ne12 = ne12,
303+ .ne13 = ne13,
304+ .scale = scale,
305+ .max_bias = max_bias,
306+ .m0 = m0,
307+ .m1 = m1
308+ };
309+
250310 if (use_f16) {
251- soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
311+ soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, params , stream);
252312 } else {
253- soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
313+ soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, params , stream);
254314 }
255315}
256316
0 commit comments