Skip to content

Commit 9b21358

Browse files
committed
metal : add comments
1 parent ea8f4bb commit 9b21358

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4453,6 +4453,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E
44534453

44544454
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
44554455

4456+
// pad the last chunk of C elements of k and v into an extra pad buffer
44564457
kernel void kernel_flash_attn_ext_pad(
44574458
constant ggml_metal_kargs_flash_attn_ext_pad & args,
44584459
device const char * k,
@@ -4483,6 +4484,7 @@ kernel void kernel_flash_attn_ext_pad(
44834484
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
44844485

44854486
if (i1 >= icp) {
4487+
// the exact value written here is not important, as we rely on masking out the scores in the attention
44864488
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
44874489
k_dst[i] = 0;
44884490
}
@@ -4696,6 +4698,7 @@ void kernel_flash_attn_ext_impl(
46964698
for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
46974699
int ic = ic0;
46984700

4701+
// the last partial chunk uses the pad buffer as source
46994702
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
47004703
k = pad;
47014704
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
@@ -5423,6 +5426,7 @@ void kernel_flash_attn_ext_vec_impl(
54235426
break;
54245427
}
54255428

5429+
// the last partial chunk uses the pad buffer as source
54265430
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
54275431
k = pad;
54285432
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;

0 commit comments

Comments
 (0)