
Commit 27cb9e7

Merge branch 'ggml-org:master' into mradermacher
2 parents: 51eb148 + e60f241

File tree: 8 files changed (+204, -109 lines)

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 1 deletion
@@ -693,7 +693,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             // for new head sizes, add checks here
-            if (op->src[0]->ne[0] != 40 &&
+            if (op->src[0]->ne[0] != 32 &&
+                op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
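
Context for the hunk above: the Metal backend reports support for GGML_OP_FLASH_ATTN_EXT only when the attention head size (op->src[0]->ne[0]) is on a known-good list, and this commit adds 32 to that list. A minimal standalone sketch of the same gate follows; the helper name is hypothetical (the real check is inline in ggml_metal_device_supports_op), and the entries beyond 96 are cut off in the hunk, so they are omitted here.

    // Sketch only: hypothetical helper mirroring the head-size whitelist above.
    static bool metal_fattn_head_size_ok(int64_t ne0) {
        return ne0 == 32 ||  // newly accepted by this commit
               ne0 == 40 ||
               ne0 == 64 ||
               ne0 == 80 ||
               ne0 == 96;    // larger sizes are truncated in the hunk above
    }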

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 109 additions & 55 deletions
Large diffs are not rendered by default.

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 5 additions & 0 deletions
@@ -2348,8 +2348,13 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
             svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
 
     if (opencl_c_version.major >= 3) {
+        // Assume it is not available for 3.0, since it is optional in 3.0.
+        // If compiling against 3.0, then we can query.
+        backend_ctx->non_uniform_workgroups = false;
+#if CL_TARGET_OPENCL_VERSION >= 300
         CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool),
                                  &backend_ctx->non_uniform_workgroups, 0));
+#endif
     } else {
         GGML_ASSERT(opencl_c_version.major == 2);
         // Non-uniform workgroup sizes is mandatory feature in v2.x.
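
The guard above follows a common pattern for optional OpenCL 3.0 features: non-uniform work-group support is mandatory in OpenCL 2.x but optional in 3.0, and the CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT query token only exists when compiling against 3.0 headers. The capability therefore defaults to false and is queried only when the compile target allows it. A minimal sketch of that default-then-query pattern outside the backend context (variable names are illustrative):

    // Sketch: default an optional capability to "unsupported", then query it
    // only when the headers expose the OpenCL 3.0 token at compile time.
    cl_bool non_uniform_wg = CL_FALSE;
    #if CL_TARGET_OPENCL_VERSION >= 300
    clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT,
                    sizeof(non_uniform_wg), &non_uniform_wg, NULL);
    #endif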

src/llama-graph.cpp

Lines changed: 74 additions & 43 deletions
@@ -261,12 +261,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
-                                (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
-                                (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
-                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    const char * swa_type_str = "unknown";
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
+        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
+        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
+        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+    };
+
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +300,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    GGML_ASSERT(kq_mask);
-    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-    float * data = (float *) kq_mask->data;
-
-    // [TAG_NO_CACHE_ISWA]
-    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+        for (int h = 0; h < 1; ++h) {
+            for (int i1 = 0; i1 < n_tokens; ++i1) {
+                const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const llama_pos    p1 = ubatch->pos[i1];
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i1 = 0; i1 < n_tokens; ++i1) {
-            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
 
-            for (int i0 = 0; i0 < n_tokens; ++i0) {
-                float f = -INFINITY;
-
-                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                for (int i0 = 0; i0 < n_tokens; ++i0) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                    const llama_pos    p0 = ubatch->pos[i0];
 
+                    // mask different sequences
                     if (s0 != s1) {
-                        continue; // skip different sequences
+                        continue;
                     }
 
-                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
-                        continue; // skip future tokens for causal attention
+                    // mask future tokens
+                    if (cparams.causal_attn && p0 > p1) {
+                        continue;
                     }
 
-                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
-                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    //    continue; // skip masked tokens for SWA
-                    //}
-
-                    // TODO: reimplement this like in llama_kv_cache_unified
-                    if (hparams.use_alibi) {
-                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                    } else {
-                        f = 0.0f;
+                    // apply SWA if any
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        continue;
                     }
+
+                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                 }
-                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
+    };
+
+    {
+        GGML_ASSERT(self_kq_mask);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+        float * data = (float *) self_kq_mask->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+        }
     }
-    if (debug) {
-        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        GGML_ASSERT(self_kq_mask_swa);
+        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+        float * data = (float *) self_kq_mask_swa->data;
+
+        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+        fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+        if (debug) {
+            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+        }
     }
 }
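
In words, the new fill_mask lambda starts from an all -INFINITY mask and clears an entry only when the key token i0 and query token i1 belong to the same sequence, i0 is not in the future under causal attention, and i0 falls inside the sliding window when one is in effect. A simplified standalone sketch of that per-pair rule for the standard window type follows (the chunked and symmetric variants handled by llama_hparams::is_masked_swa differ and are not reproduced here):

    // Sketch only: simplified masking rule, mirroring the lambda above.
    // Returns true if key position p0 must stay masked for query position p1.
    static bool pair_is_masked(llama_seq_id s0, llama_seq_id s1,
                               llama_pos p0, llama_pos p1,
                               bool causal, int n_swa /* 0 = no window */) {
        if (s0 != s1)                      return true; // different sequences
        if (causal && p0 > p1)             return true; // future token
        if (n_swa > 0 && p1 - p0 >= n_swa) return true; // outside standard window
        return false;
    }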

@@ -1299,12 +1321,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
-    const auto n_kv = k->ne[1];
-
     ggml_tensor * cur;
 
-    // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+    if (cparams.flash_attn && kq_b == nullptr) {
         GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
         if (v_trans) {
@@ -1419,10 +1438,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->kq_mask);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    } else {
+        inp->self_kq_mask_swa     = nullptr;
+        inp->self_kq_mask_swa_cnv = nullptr;
+    }
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
@@ -1447,7 +1476,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const bool is_swa = hparams.is_swa(il);
+
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams

src/llama-graph.h

Lines changed: 7 additions & 3 deletions
@@ -257,10 +257,14 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv;     }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
+    // n_tokens == n_batch
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
 
     const llama_hparams hparams;
     const llama_cparams cparams;
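
Why the header now carries two masks: in an iSWA model only some layers use a sliding window, so the cacheless attention input exposes both a dense mask and an SWA mask, and each layer picks one by its index, matching the hparams.is_swa(il) check in build_attn() above. A hypothetical usage sketch (the loop and variable names are illustrative, not from the commit):

    // Sketch: per-layer selection between the dense and the SWA mask.
    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * kq_mask = hparams.is_swa(il) ? inp->get_kq_mask_swa()
                                                   : inp->get_kq_mask();
        // ... build the attention block for layer il using kq_mask ...
    }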

src/llama-model.cpp

Lines changed: 5 additions & 6 deletions
@@ -11368,8 +11368,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma_embedding_iswa : public llm_graph_context {
-    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding : public llm_graph_context {
+    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;

@@ -11386,8 +11386,7 @@ struct llm_build_gemma_embedding : public llm_graph_context {
         // inp_pos - contains the positions
        ggml_tensor * inp_pos = build_inp_pos();
 
-        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
-        auto * inp_attn = build_attn_inp_kv_iswa();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 

@@ -19388,7 +19387,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:

@@ -19681,7 +19680,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
-                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+                llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
             } break;
         case LLM_ARCH_STARCODER2:
             {

src/llama.cpp

Lines changed: 1 addition & 0 deletions
@@ -312,6 +312,7 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
+    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
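
The added reserve is a small allocation tweak: requesting capacity for n_paths entries up front means the push_back loop that follows cannot trigger intermediate reallocations. The same pattern in isolation (a generic sketch, not the surrounding loader code):

    // Sketch: reserve once, then push_back without further reallocation.
    std::vector<std::string> splits;
    splits.reserve(n_paths);
    for (size_t i = 0; i < n_paths; ++i) {
        splits.push_back(paths[i]); // no capacity growth past this point
    }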

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -6779,7 +6779,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int nb : { 1, 3, 32, 35, }) {
                         for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
                             if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                            for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                 test_cases.emplace_back(new test_flash_attn_ext(
                                     hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
                                 // run fewer test cases permuted
