
Commit 13cb2f9

Revert "graph : support cacheless embeddings with FA and iSWA (ggml-org#16528)"
This reverts commit e38b7c6.
1 parent 56c5cc3 commit 13cb2f9

File tree

4 files changed: +51 -87 lines changed


src/llama-graph.cpp

Lines changed: 42 additions & 74 deletions
@@ -261,17 +261,12 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
-    const char * swa_type_str = "unknown";
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:      swa_type_str = "LLAMA_SWA_TYPE_NONE";      break;
-        case LLAMA_SWA_TYPE_STANDARD:  swa_type_str = "LLAMA_SWA_TYPE_STANDARD";  break;
-        case LLAMA_SWA_TYPE_CHUNKED:   swa_type_str = "LLAMA_SWA_TYPE_CHUNKED";   break;
-        case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
-    };
-
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE"      :
+                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD"  :
+                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED"   :
+                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
     LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
     LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
     LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -300,67 +295,50 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
 
-    const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
-        for (int h = 0; h < 1; ++h) {
-            for (int i1 = 0; i1 < n_tokens; ++i1) {
-                const llama_seq_id s1 = ubatch->seq_id[i1][0];
-                const llama_pos    p1 = ubatch->pos[i1];
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
 
-                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
 
-                for (int i0 = 0; i0 < n_tokens; ++i0) {
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
-                    const llama_pos    p0 = ubatch->pos[i0];
 
-                    // mask different sequences
                     if (s0 != s1) {
-                        continue;
+                        continue; // skip different sequences
                     }
 
-                    // mask future tokens
-                    if (cparams.causal_attn && p0 > p1) {
-                        continue;
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
                     }
 
-                    // apply SWA if any
-                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
-                        continue;
-                    }
+                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                    //    continue; // skip masked tokens for SWA
+                    //}
 
-                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
+                    }
                 }
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
-    };
-
-    {
-        GGML_ASSERT(self_kq_mask);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-        float * data = (float *) self_kq_mask->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
-
-        fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
-        }
     }
-
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        GGML_ASSERT(self_kq_mask_swa);
-        GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-
-        float * data = (float *) self_kq_mask_swa->data;
-
-        std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
-
-        fill_mask(data, hparams.n_swa, hparams.swa_type);
-
-        if (debug) {
-            print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
-        }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
     }
 }
 
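The restored set_input builds the cacheless KQ mask directly: with no KV cache, n_kv == n_tokens, every slot starts at -INFINITY, cross-sequence and (for causal attention) future positions stay masked, and allowed positions get either 0 or an ALiBi-style -|p0 - p1| bias. A self-contained sketch of the same idea, using plain vectors of positions and sequence ids instead of llama_ubatch (the function name and signature are illustrative, not the llama.cpp API):

// Sketch of the cacheless KQ mask fill, under the simplifying assumptions above.
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <vector>

std::vector<float> build_kq_mask(const std::vector<int32_t> & pos,
                                 const std::vector<int32_t> & seq_id,
                                 bool causal, bool use_alibi) {
    const size_t n_tokens = pos.size();
    const size_t n_kv     = n_tokens; // no KV cache: keys == batch tokens

    std::vector<float> mask(n_tokens*n_kv, -INFINITY);

    for (size_t i1 = 0; i1 < n_tokens; ++i1) {     // query token
        for (size_t i0 = 0; i0 < n_kv; ++i0) {     // key/value token
            if (seq_id[i0] != seq_id[i1]) {
                continue; // different sequences never attend to each other
            }
            if (causal && pos[i0] > pos[i1]) {
                continue; // causal attention: no peeking at future tokens
            }
            // 0 = attend; with ALiBi the bias is the negative distance between positions
            mask[i1*n_kv + i0] = use_alibi ? -std::abs(pos[i0] - pos[i1]) : 0.0f;
        }
    }
    return mask;
}
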

@@ -1357,10 +1335,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
+    const auto n_kv = k->ne[1];
+
     ggml_tensor * cur;
 
     // TODO: replace hardcoded padding with ggml-provided padding
-    if (cparams.flash_attn && kq_b == nullptr) {
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
 
        if (v_trans) {
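
The restored condition gates the flash-attention branch on the hardcoded padding mentioned in the TODO: it is taken only when flash attention is enabled, the KV length is a multiple of 256, and no KQ bias is present. A hypothetical helper (not part of llama.cpp) expressing the same predicate:

#include <cstdint>

// Sketch: when would the flash-attention branch above be taken?
static bool use_flash_attn_path(bool flash_attn_enabled, int64_t n_kv, const void * kq_b) {
    return flash_attn_enabled      // flash attention requested via cparams
        && (n_kv % 256 == 0)       // KV length must match the hardcoded FA padding
        && kq_b == nullptr;        // FA does not support a KQ bias yet
}
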
@@ -1475,20 +1455,10 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+    ggml_set_input(inp->kq_mask);
 
-    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    } else {
-        inp->self_kq_mask_swa = nullptr;
-        inp->self_kq_mask_swa_cnv = nullptr;
-    }
+    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
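
In both versions the mask tensor's second dimension is rounded up with GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), i.e. padded to the next multiple of the mask padding constant. A small sketch of that round-up arithmetic, with 64 used purely as an example padding value (the actual GGML_KQ_MASK_PAD value is not shown in this diff):

#include <cstdint>
#include <cstdio>

// Sketch: round x up to the next multiple of n, as the padding macro does.
constexpr int64_t pad_to(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t n_tokens = 100;
    // mask shape would be [n_tokens, pad_to(n_tokens, 64), 1, 1] -> second dim padded to 128
    std::printf("padded rows: %lld\n", (long long) pad_to(n_tokens, 64));
    return 0;
}
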
@@ -1513,9 +1483,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = inp->get_kq_mask();
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams

src/llama-graph.h

Lines changed: 3 additions & 7 deletions
@@ -257,14 +257,10 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
-    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
+    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
-    // n_tokens == n_batch
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_tokens, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
 
     const llama_hparams hparams;
     const llama_cparams cparams;

src/llama-model.cpp

Lines changed: 6 additions & 5 deletions
@@ -11531,8 +11531,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma_embedding : public llm_graph_context {
-    llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -11549,7 +11549,8 @@ struct llm_build_gemma_embedding : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -19694,7 +19695,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        case LLM_ARCH_GEMMA_EMBEDDING:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
         case LLM_ARCH_LLADA_MOE:
@@ -19987,7 +19988,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GEMMA_EMBEDDING:
             {
-                llm = std::make_unique<llm_build_gemma_embedding>(*this, params);
+                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
             } break;
         case LLM_ARCH_STARCODER2:
             {

src/llama.cpp

Lines changed: 0 additions & 1 deletion
@@ -346,7 +346,6 @@ struct llama_model * llama_model_load_from_splits(
         LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
         return nullptr;
     }
-    splits.reserve(n_paths);
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
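
The reverted commit had added a reserve call before this loop; the revert restores the plain push_back fill. As a general C++ note, reserving capacity when the final size is known up front avoids intermediate reallocations without changing behavior, as in this generic sketch (not llama.cpp code):

#include <string>
#include <vector>

std::vector<std::string> collect_paths(const char ** paths, size_t n_paths) {
    std::vector<std::string> splits;
    splits.reserve(n_paths);              // optional: one allocation instead of repeated growth
    for (size_t i = 0; i < n_paths; ++i) {
        splits.push_back(paths[i]);       // same result with or without the reserve
    }
    return splits;
}
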
