Skip to content

Commit bd1198a

Browse files
committed
metal : fix build and some more comments
1 parent 5b359bb commit bd1198a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

ggml/src/ggml-metal.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3046,6 +3046,8 @@ static void ggml_metal_encode_node(
30463046

30473047
bool use_vec_kernel = false;
30483048

3049+
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
3050+
// for now avoiding mainly to keep the number of templates/kernels a bit lower
30493051
if (ne01 >= 4 || (ne00%128 != 0)) {
30503052
switch (src1->type) {
30513053
case GGML_TYPE_F16:

ggml/src/ggml-metal.metal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
33563356
const short D4 = D/4;
33573357
const short D16 = D/16;
33583358
const short NW = N_SIMDWIDTH;
3359-
const short NL = NW/4;
3360-
const short SH = 2*C; // shared memory per simdgroup
3359+
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
3360+
const short SH = 2*C; // shared memory per simdgroup
33613361

33623362
const short T = D + nsg*SH; // shared memory size per query in (half)
33633363

@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
34483448

34493449
// Q*K^T
34503450
{
3451-
// each simdgroup processes 1 query and 4 keys
3451+
// each simdgroup processes 1 query and 4 (NW/NL) keys
34523452
for (short cc = 0; cc < C/4; ++cc) {
34533453
qk_t mqk = 0.0;
34543454

@@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext_vec(
36453645
half, half4, half4x4, \
36463646
half4x4
36473647

3648-
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
3648+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
36493649

36503650
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
36513651
#if defined(GGML_METAL_USE_BF16)

0 commit comments

Comments
 (0)