 #include "softmax.cuh"
 
 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
@@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const half * src1_d = src1 ? (const half *)src1->data : nullptr;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // positions tensor
-    float * src2_dd = nullptr;
+    half * src2_dd = nullptr;
 
     ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_dd = (half *)src2->data;
     }
 
     soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
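
For reference, the core of this change is that the optional mask and positions buffers are now read as half and widened to float with __half2float before the per-element arithmetic. A minimal standalone sketch of that access pattern (hypothetical kernel and parameter names, single row only, not part of the patch):

#include <cuda_fp16.h>

// Sketch only: applies the same per-element formula as the patched kernel,
// x[col]*scale + mask + slope*pos, reading the optional F16 buffers with
// __half2float. Names and the single-row layout are assumed for illustration.
static __global__ void scale_mask_pos_f16(const float * x, const half * mask, const half * pos,
                                          float * dst, const int ncols, const float scale, const float slope) {
    const int col = blockIdx.x*blockDim.x + threadIdx.x;
    if (col >= ncols) {
        return;
    }

    const float m = mask ? __half2float(mask[col]) : 0.0f; // optional F16 attention mask
    const float p = pos  ? __half2float(pos[col])  : 0.0f; // optional F16 positions (ALiBi slope applied)
    dst[col] = x[col]*scale + m + slope*p;
}

The real soft_max_f32 kernel then reduces max_val and the exponentials across the block before normalizing, which this sketch omits.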