
Commit c15eca8

Merge pull request #199 from menloresearch/update-dev-from-master-2025-08-09-00-12
Sync master with upstream release b6121
2 parents: f68cb3c + e54d41b

File tree: 19 files changed, +329 −128 lines


ggml/src/ggml-blas/ggml-blas.cpp (4 additions, 4 deletions)

@@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_blas_context * ctx = new ggml_backend_blas_context;

     ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_blas_guid(),
-        /* .interface = */ blas_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
-        /* .context   = */ ctx,
+        /* .guid      = */ ggml_backend_blas_guid(),
+        /* .iface     = */ blas_backend_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
+        /* .context   = */ ctx,
     };

 #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
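A note on what this four-line hunk (and the matching ones in ggml-cpu.cpp, ggml-cuda.cu, and ggml-opencl.cpp below) changes: the members of ggml_backend are initialized positionally, so the /* .name = */ comments are documentation only and must track the real field names; the struct's second member is named iface, not interface. A minimal sketch of the convention, using a hypothetical stand-in type:

    // Illustrative stand-in for ggml_backend; only the comment labels matter here.
    struct backend_like {
        int    guid;
        int    iface;    // the initializer comment must say ".iface" to match
        void * device;
        void * context;
    };

    backend_like b = {
        /* .guid    = */ 0,
        /* .iface   = */ 1,         // previously mislabeled as "/* .interface = */"
        /* .device  = */ nullptr,
        /* .context = */ nullptr,
    };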

ggml/src/ggml-cpu/ggml-cpu.cpp (4 additions, 4 deletions)

@@ -214,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->abort_callback_data = NULL;

     ggml_backend_t cpu_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cpu_guid(),
-        /* .interface = */ ggml_backend_cpu_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context   = */ ctx,
+        /* .guid      = */ ggml_backend_cpu_guid(),
+        /* .iface     = */ ggml_backend_cpu_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context   = */ ctx,
     };

     if (cpu_backend == NULL) {

ggml/src/ggml-cuda/fattn-mma-f16.cuh (70 additions, 17 deletions)

@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }

+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
+    // so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new  = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(

     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

-        const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }

     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;

-    const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const float2 * Q_f2  = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2  = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2       * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

     const int kb0_start_kernel = kb0_start * kb_niter;
     int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
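For intuition, the new block above is the standard online-softmax update with the attention sink treated as one extra logit per head: the sink may raise the running maximum (forcing a re-scale of the accumulated values) and always joins the row sum, but it contributes no value vector. A minimal scalar sketch of the same math (illustrative names, not the kernel's code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar model of the kernel's sink handling: the sink joins the max and
    // the normalizer but emits no value of its own.
    std::vector<float> softmax_with_sink(const std::vector<float> & logits, float sink) {
        float m = sink;
        for (float x : logits) m = std::max(m, x);   // sink may raise the max
        float sum = std::exp(sink - m);              // sink always joins the row sum
        for (float x : logits) sum += std::exp(x - m);
        std::vector<float> p;
        for (float x : logits) p.push_back(std::exp(x - m) / sum);
        return p;                                    // elements sum to less than 1
    }

In the tiled kernel the same update appears as KQ_max_new = fmaxf(KQ_max[col], sink), a re-scale of the VKQ accumulators by expf(KQ_max[col] - KQ_max_new), and KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + expf(sink - KQ_max_new).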

ggml/src/ggml-cuda/fattn.cu (1 addition, 1 deletion)

@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {

ggml/src/ggml-cuda/ggml-cuda.cu (6 additions, 5 deletions)

@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
         }
         // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-        if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+        if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+            && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
             return false;
         }
         if (op->src[0]->ne[0] == 192) {
@@ -3798,10 +3799,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
     }

     ggml_backend_t cuda_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cuda_guid(),
-        /* .interface = */ ggml_backend_cuda_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
-        /* .context   = */ ctx,
+        /* .guid      = */ ggml_backend_cuda_guid(),
+        /* .iface     = */ ggml_backend_cuda_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+        /* .context   = */ ctx,
     };

     return cuda_backend;
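Taken together with the fattn.cu change above, the implied dispatch rule for attention sinks is: devices with fp16 MMA support use the MMA kernel, while all other devices fall back to the vec kernels and accept sinks only for head sizes 64 and 128. A hedged sketch of that predicate (illustrative signature, not actual ggml API):

    #include <cstdint>

    // Mirrors the two changed conditions; for illustration only.
    bool flash_attn_sinks_supported(bool has_fp16_mma, int64_t head_size) {
        if (has_fp16_mma) {
            return true;                            // handled by the MMA kernel
        }
        return head_size == 64 || head_size == 128; // vec-kernel fallback
    }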

ggml/src/ggml-opencl/ggml-opencl.cpp (32 additions, 25 deletions)

@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
-            // TODO: support attention sinks [TAG_ATTN_SINKS]
-            return op->src[2] == nullptr;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
@@ -2626,10 +2624,10 @@ ggml_backend_t ggml_backend_opencl_init(void) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);

     ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_opencl_guid(),
-        /* .interface = */ ggml_backend_opencl_i,
-        /* .device    = */ dev,
-        /* .context   = */ backend_ctx
+        /* .guid      = */ ggml_backend_opencl_guid(),
+        /* .iface     = */ ggml_backend_opencl_i,
+        /* .device    = */ dev,
+        /* .context   = */ backend_ctx
     };

     return backend;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
         GGML_ASSERT(src1->extra);
     }

+    const ggml_tensor * src2 = dst->src[2];
+    if (src2) {
+        GGML_ASSERT(src2->extra);
+    }
+
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

     ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;

     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;

     cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;

     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &scale));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float),    &max_bias));
-    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &m0));
-    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &m1));
-    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &n_head_log2));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_head_log2));

     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[]  = {(size_t)nth, 1, 1};
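The renumbering above is mechanical but easy to get wrong: clSetKernelArg indices must match the kernel's parameter order, so inserting (src2, offset2) at positions 4-5 shifts every later argument up by two. One common guard against this class of bug, shown here only as a hedged sketch and not as part of the patch, is to assign indices from a running counter:

    #include <CL/cl.h>

    // Assign argument indices from a counter so a parameter inserted
    // mid-signature renumbers everything after it automatically.
    template <typename T>
    cl_int set_next_arg(cl_kernel kernel, cl_uint & idx, const T & value) {
        return clSetKernelArg(kernel, idx++, sizeof(T), &value);
    }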

ggml/src/ggml-opencl/kernels/softmax_4_f16.cl (10 additions, 2 deletions)

@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst = dst + offsetd;

     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(

     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);

     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
     }

     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;

-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }

     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;
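One consequence implied by the kernel (not stated in the patch): the sink logit s enters the denominator but never the numerator, so each softmax row now sums to

    sum_i p_i = (sum_j exp(x_j - m)) / (exp(s - m) + sum_j exp(x_j - m)) < 1,   where m = max(s, max_j x_j),

i.e. the sink absorbs part of each row's attention mass, which is the intended effect of attention sinks.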
