File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -4453,6 +4453,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E
4453
4453
4454
4454
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24 )]];
4455
4455
4456
+ // pad the last chunk of C elements of k and v into a an extra pad buffer
4456
4457
kernel void kernel_flash_attn_ext_pad (
4457
4458
constant ggml_metal_kargs_flash_attn_ext_pad & args,
4458
4459
device const char * k,
@@ -4483,6 +4484,7 @@ kernel void kernel_flash_attn_ext_pad(
4483
4484
device char * v_dst = v_pad + args.nb21 *i1 + args.nb21 *C*i2 + args.nb21 *C*args.ne_12_2 *i3;
4484
4485
4485
4486
if (i1 >= icp) {
4487
+ // here it is not important the exact value that will be used as we rely on masking out the scores in the attention
4486
4488
for (uint64_t i = tiitg; i < args.nb11 ; i += ntg.x ) {
4487
4489
k_dst[i] = 0 ;
4488
4490
}
@@ -4696,6 +4698,7 @@ void kernel_flash_attn_ext_impl(
4696
4698
for (int ic0 = 0 ; ic0 < args.ne11 ; ic0 += C) {
4697
4699
int ic = ic0;
4698
4700
4701
+ // the last partial chunk uses the pad buffer as source
4699
4702
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11 ) {
4700
4703
k = pad;
4701
4704
v = k + args.nb11 *C*args.ne_12_2 *args.ne_12_3 ;
@@ -5423,6 +5426,7 @@ void kernel_flash_attn_ext_vec_impl(
5423
5426
break ;
5424
5427
}
5425
5428
5429
+ // the last partial chunk uses the pad buffer as source
5426
5430
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11 ) {
5427
5431
k = pad;
5428
5432
v = k + args.nb11 *C*args.ne_12_2 *args.ne_12_3 ;
You can’t perform that action at this time.
0 commit comments