
Commit 51eda92

refactor llm_graph_input_attn_no_cache::set_input
1 parent 3387586 commit 51eda92


1 file changed: +84, -108 lines


src/llama-graph.cpp

Lines changed: 84 additions & 108 deletions
@@ -265,155 +265,131 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv = ubatch->n_tokens;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
+    // Helper function for SWA masking logic - mirrors llama_kv_cache_unified::is_masked_swa
+    auto is_masked_swa = [this](llama_pos p0, llama_pos p1) -> bool {
+        assert(p0 >= 0 && p1 >= 0);
 
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
+        switch (hparams.swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) hparams.n_swa) {
+                        return true;
                     }
-                }
-            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            const int64_t n_stride = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / hparams.n_swa) * hparams.n_swa;
 
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) hparams.n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
                     }
-                }
-            }
+                } break;
+        }
+
+        return false;
+    };
+
+    // Helper function for setting attention mask
+    auto set_mask = [this, ubatch, &is_masked_swa](ggml_tensor * mask, bool apply_swa) {
+        if (!mask) {
+            return;
         }
-    }
 
-    // Handle symmetric SWA mask separately if it exists
-    if (kq_mask_swa) {
         const int64_t n_tokens = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs = ubatch->n_seqs;
-        const int64_t n_stride = ubatch->n_tokens;
-        const int64_t half_n_swa = hparams.n_swa / 2;
+        const int64_t n_kv = ubatch->n_tokens;
+        const int64_t n_stride = cparams.causal_attn ? n_kv : n_tokens;
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask_swa->buffer));
-
-        float * data = (float *) kq_mask_swa->data;
+        GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
+        float * data = (float *) mask->data;
 
         for (int h = 0; h < 1; ++h) {
             for (int s1 = 0; s1 < n_seqs; ++s1) {
                 const llama_seq_id seq_id = ubatch->seq_id[s1][0];
 
                 for (int j = 0; j < n_seq_tokens; ++j) {
                     const int32_t tj = s1*n_seq_tokens + j;
-                    const int64_t pos_j = ubatch->pos[tj];
+                    const llama_pos pos_j = ubatch->pos[tj];
 
                     for (int s0 = 0; s0 < n_seqs; ++s0) {
                         for (int i = 0; i < n_seq_tokens; ++i) {
                             const int32_t ti = s0*n_seq_tokens + i;
+                            const llama_pos pos_i = ubatch->pos[ti];
                             float f = -INFINITY;
 
-                            // TODO: fix indexing [UBATCH_IDX]
+                            // Check sequence match
+                            bool sequence_match = false;
                             for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                 if (ubatch->seq_id[s0][s] == seq_id) {
-                                    const int64_t pos_i = ubatch->pos[ti];
-                                    const int64_t pos_diff = pos_j - pos_i;
+                                    sequence_match = true;
+                                    break;
+                                }
+                            }
 
-                                    // Check both causal attention and symmetric sliding window
-                                    bool masked = false;
+                            if (sequence_match) {
+                                bool masked = false;
 
-                                    // Apply causal attention if enabled (only allow attention to past tokens)
-                                    if (cparams.causal_attn && pos_i > pos_j) {
-                                        masked = true;
-                                    }
+                                // Apply causal attention if enabled
+                                if (cparams.causal_attn && pos_i > pos_j) {
+                                    masked = true;
+                                }
+
+                                // Apply SWA masking if needed
+                                if (!masked && apply_swa) {
+                                    masked = masked || is_masked_swa(pos_i, pos_j);
+                                }
 
-                                    // Apply symmetric sliding window attention logic
-                                    if (!masked && pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(pos_i - pos_j);
-                                        } else {
-                                            f = 0.0f;
-                                        }
+                                if (!masked) {
+                                    if (hparams.use_alibi) {
+                                        f = -std::abs(pos_i - pos_j);
+                                    } else {
+                                        f = 0.0f;
                                     }
-                                    break;
                                 }
                             }
 
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            const int idx = h*(n_tokens*n_tokens) + tj*n_stride + ti;
+                            data[idx] = f;
                         }
                     }
 
+                    // Pad the rest of the row with -INFINITY
                    for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        const int idx = h*(n_tokens*n_tokens) + tj*n_stride + i;
+                        data[idx] = -INFINITY;
                    }
                 }
             }
         }
-    }
+
+        // Pad any remaining entries with -INFINITY
+        for (int tj = n_tokens; tj < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++tj) {
+            for (int ti = 0; ti < n_stride; ++ti) {
+                const int idx = 0*(n_tokens*n_tokens) + tj*n_stride + ti;
+                data[idx] = -INFINITY;
+            }
+        }
+    };
+
+    // Set regular attention mask
+    set_mask(kq_mask, false);
+
+    // Set SWA attention mask if available
+    set_mask(kq_mask_swa, true);
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
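Note on the new masking rules: the is_masked_swa lambda above distinguishes three window types (standard, chunked, symmetric). Below is a minimal standalone sketch of how those rules behave; the SwaType enum, the window size kNSwa, and the main driver are illustrative stand-ins for this note only, not the llama.cpp definitions (LLAMA_SWA_TYPE_*, hparams.n_swa) used in the diff.

// Standalone sketch mirroring the window rules in is_masked_swa above.
// SwaType and kNSwa are illustrative stand-ins, not llama.cpp definitions.
#include <cassert>
#include <cstdint>
#include <cstdio>

enum class SwaType { None, Standard, Chunked, Symmetric };

constexpr int32_t kNSwa = 4; // window size (stand-in for hparams.n_swa)

// Returns true when key position p0 must be masked for query position p1.
bool is_masked_swa(SwaType type, int32_t p0, int32_t p1) {
    assert(p0 >= 0 && p1 >= 0);

    switch (type) {
        case SwaType::None:
            break;
        case SwaType::Standard:
            // only the last kNSwa positions before the query stay visible
            if (p1 - p0 >= kNSwa) {
                return true;
            }
            break;
        case SwaType::Chunked:
            // visibility is limited to the chunk that contains the query
            if (p0 < (p1 / kNSwa) * kNSwa) {
                return true;
            }
            break;
        case SwaType::Symmetric: {
            // +/- kNSwa/2 window centered on the query position
            const int32_t half_n_swa = kNSwa / 2;
            const int32_t pos_diff   = p1 - p0;
            if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
                return true;
            }
            break;
        }
    }

    return false;
}

int main() {
    const int32_t p1 = 8; // query position
    for (int32_t p0 = 0; p0 <= 10; ++p0) {
        std::printf("p0=%2d  standard=%d  chunked=%d  symmetric=%d\n", p0,
                (int) is_masked_swa(SwaType::Standard,  p0, p1),
                (int) is_masked_swa(SwaType::Chunked,   p0, p1),
                (int) is_masked_swa(SwaType::Symmetric, p0, p1));
    }
    return 0;
}

For example, with a window of 4 and a query at position 8, the standard rule masks key positions 0..4, the chunked rule masks everything before position 8 (the start of the chunk containing the query), and the symmetric rule keeps only positions 6..10 visible. In the refactored function, set_mask applies these rules only when apply_swa is true, i.e. when filling kq_mask_swa; kq_mask gets causal and sequence filtering alone.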

0 commit comments
