Skip to content

Commit 50d2b21

Browse files
committed
metal : add comments
1 parent 0629437 commit 50d2b21

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4420,6 +4420,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E
44204420

44214421
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
44224422

4423+
// pad the last chunk of C elements of k and v into an extra pad buffer
44234424
kernel void kernel_flash_attn_ext_pad(
44244425
constant ggml_metal_kargs_flash_attn_ext_pad & args,
44254426
device const char * k,
@@ -4450,6 +4451,7 @@ kernel void kernel_flash_attn_ext_pad(
44504451
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
44514452

44524453
if (i1 >= icp) {
4454+
// the exact value used here is not important, since we rely on masking out the scores in the attention
44534455
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
44544456
k_dst[i] = 0;
44554457
}
@@ -4663,6 +4665,7 @@ void kernel_flash_attn_ext_impl(
46634665
for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
46644666
int ic = ic0;
46654667

4668+
// the last partial chunk uses the pad buffer as source
46664669
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
46674670
k = pad;
46684671
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
@@ -5390,6 +5393,7 @@ void kernel_flash_attn_ext_vec_impl(
53905393
break;
53915394
}
53925395

5396+
// the last partial chunk uses the pad buffer as source
53935397
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
53945398
k = pad;
53955399
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;

0 commit comments

Comments
 (0)