Commit 94af548

metal : fix comments + remove unnecessary addition
ggml-ci
1 parent 9c2b783 commit 94af548

3 files changed: 6 additions, 10 deletions

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 1 addition & 0 deletions
@@ -3834,6 +3834,7 @@ static void ggml_metal_encode_node(
 
     // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
     // for now avoiding mainly to keep the number of templates/kernels a bit lower
+    // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
     if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
         switch (src1->type) {
             case GGML_TYPE_F16:

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 5 deletions
@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4>
 void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
-    reg = (type4)(*(src + il));
+    reg = (type4)(*(src));
 }
 
 #if defined(GGML_METAL_USE_BF16)
@@ -59,7 +59,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4>
 void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
-    reg = (type4)(*(src + il));
+    reg = (type4)(*(src));
 }
 #endif

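On the "unnecessary addition" removed in the two hunks above: the t4 dequantizers take a group index il so that quantized block types, which pack several 4-value groups per block, can select one group; a half4/bfloat4 source already is exactly one such group, so il is presumably always 0 at the F16/BF16 call sites and src + il reduces to src. A minimal sketch of that assumption (illustrative only, not code from the repository):

    #include <metal_stdlib>
    using namespace metal;

    // Sketch: a half4 "block" holds a single 4-value group, so under this assumption
    // the group index is always 0 and the pointer needs no adjustment before the read.
    template <typename type4>
    void dequantize_f16_t4_sketch(device const half4 * src, short il, thread type4 & reg) {
        (void) il;           // assumed to be 0 for F16 sources; kept only to match the signature
        reg = (type4)(*src); // read the one and only group
    }
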
@@ -3644,7 +3644,7 @@ kernel void kernel_flash_attn_ext_vec(
     const short DK4 = DK/4;
     const short DV4 = DV/4;
     const short NW = N_SIMDWIDTH;
-    const short NL = NW/NE; // note: this can be adjusted to support different head sizes simdgroup work loads
+    const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
     const short SH = 2*C; // shared memory per simdgroup
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
@@ -3656,7 +3656,7 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
     threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
 
-    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
 
     // load heads from Q to shared memory
@@ -3756,7 +3756,7 @@ kernel void kernel_flash_attn_ext_vec(
             mqk += dot((float4) mk, (float4) sq4[i]);
         }
 
-        static_assert(NE > 1, "NE must be > 1");
+        static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
 
         // simdgroup reduce (NE = 4)
         // [ 0 .. 7] -> [ 0]
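
A note on the constants touched above: with N_SIMDWIDTH = 32 (the SIMD width used by ggml-metal on Apple GPUs) and NE = 4, NL = NW/NE = 8, so the simdgroup is split into NE groups of NL consecutive lanes and each group's partial dot products appear to be folded into its first lane, which is what the "[ 0 .. 7] -> [ 0]" comment describes. A hedged sketch of such a per-group reduction (illustrative only, not the kernel's exact code):

    #include <metal_stdlib>
    using namespace metal;

    // Illustrative only: fold NL consecutive lanes into the first lane of each group.
    // With NW = 32 and NE = 4, NL = 8 and the result is meaningful on lanes 0, 8, 16, 24.
    template <short NL>
    float simdgroup_reduce_group(float partial) {
        float s = partial;
        for (short offset = NL/2; offset > 0; offset /= 2) {
            s += simd_shuffle_down(s, (ushort) offset);
        }
        return s;
    }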

src/llama-context.cpp

Lines changed: 0 additions & 5 deletions
@@ -2316,11 +2316,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    //if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-    //    LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-    //    params.flash_attn = false;
-    //}
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
