@@ -265,155 +265,131 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
+    // Helper function for SWA masking logic - mirrors llama_kv_cache_unified::is_masked_swa
+    auto is_masked_swa = [this](llama_pos p0, llama_pos p1) -> bool {
+        assert(p0 >= 0 && p1 >= 0);
 
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
+        switch (hparams.swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) hparams.n_swa) {
+                        return true;
                     }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / hparams.n_swa) * hparams.n_swa;
 
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) hparams.n_swa / 2;
+                    const int32_t pos_diff   = p1 - p0;
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
                     }
-                }
-            }
+                } break;
+        }
+
+        return false;
+    };
+
+    // Helper function for setting attention mask
+    auto set_mask = [this, ubatch, &is_masked_swa](ggml_tensor * mask, bool apply_swa) {
+        if (!mask) {
+            return;
         }
-    }
 
-    // Handle symmetric SWA mask separately if it exists
-    if (kq_mask_swa) {
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
-        const int64_t n_stride     = ubatch->n_tokens;
-        const int64_t half_n_swa   = hparams.n_swa / 2;
+        const int64_t n_kv         = ubatch->n_tokens;
+        const int64_t n_stride     = cparams.causal_attn ? n_kv : n_tokens;
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask_swa->buffer));
-
-        float * data = (float *) kq_mask_swa->data;
+        GGML_ASSERT(ggml_backend_buffer_is_host(mask->buffer));
+        float * data = (float *) mask->data;
 
         for (int h = 0; h < 1; ++h) {
             for (int s1 = 0; s1 < n_seqs; ++s1) {
                 const llama_seq_id seq_id = ubatch->seq_id[s1][0];
 
                 for (int j = 0; j < n_seq_tokens; ++j) {
                     const int32_t tj = s1*n_seq_tokens + j;
-                    const int64_t pos_j = ubatch->pos[tj];
+                    const llama_pos pos_j = ubatch->pos[tj];
 
                     for (int s0 = 0; s0 < n_seqs; ++s0) {
                         for (int i = 0; i < n_seq_tokens; ++i) {
                             const int32_t ti = s0*n_seq_tokens + i;
+                            const llama_pos pos_i = ubatch->pos[ti];
                             float f = -INFINITY;
 
-                            // TODO: fix indexing [UBATCH_IDX]
+                            // Check sequence match
+                            bool sequence_match = false;
                             for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                 if (ubatch->seq_id[s0][s] == seq_id) {
-                                    const int64_t pos_i = ubatch->pos[ti];
-                                    const int64_t pos_diff = pos_j - pos_i;
+                                    sequence_match = true;
+                                    break;
+                                }
+                            }
 
-                                    // Check both causal attention and symmetric sliding window
-                                    bool masked = false;
+                            if (sequence_match) {
+                                bool masked = false;
 
-                                    // Apply causal attention if enabled (only allow attention to past tokens)
-                                    if (cparams.causal_attn && pos_i > pos_j) {
-                                        masked = true;
-                                    }
+                                // Apply causal attention if enabled
+                                if (cparams.causal_attn && pos_i > pos_j) {
+                                    masked = true;
+                                }
+
+                                // Apply SWA masking if needed
+                                if (!masked && apply_swa) {
+                                    masked = masked || is_masked_swa(pos_i, pos_j);
+                                }
 
-                                    // Apply symmetric sliding window attention logic
-                                    if (!masked && pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(pos_i - pos_j);
-                                        } else {
-                                            f = 0.0f;
-                                        }
+                                if (!masked) {
+                                    if (hparams.use_alibi) {
+                                        f = -std::abs(pos_i - pos_j);
+                                    } else {
+                                        f = 0.0f;
                                     }
-                                    break;
                                 }
                             }
 
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            const int idx = h*(n_tokens*n_tokens) + tj*n_stride + ti;
+                            data[idx] = f;
                         }
                     }
 
+                    // Pad the rest of the row with -INFINITY
                     for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        const int idx = h*(n_tokens*n_tokens) + tj*n_stride + i;
+                        data[idx] = -INFINITY;
                     }
                 }
             }
         }
-    }
+
+        // Pad any remaining entries with -INFINITY
+        for (int tj = n_tokens; tj < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++tj) {
+            for (int ti = 0; ti < n_stride; ++ti) {
+                const int idx = 0*(n_tokens*n_tokens) + tj*n_stride + ti;
+                data[idx] = -INFINITY;
+            }
+        }
+    };
+
+    // Set regular attention mask
+    set_mask(kq_mask, false);
+
+    // Set SWA attention mask if available
+    set_mask(kq_mask_swa, true);
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
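
For reference, below is a minimal standalone sketch of the window predicates that the new is_masked_swa lambda in this diff implements. The enum name, the function signature, the n_swa value, and the main() driver are illustrative stand-ins for this sketch only; they are not part of the patch or of llama.cpp's API.

// Standalone sketch of the SWA window predicates (illustrative names, not llama.cpp's).
#include <cstdint>
#include <cstdio>

enum swa_type { SWA_NONE, SWA_STANDARD, SWA_CHUNKED, SWA_SYMMETRIC };

// Returns true when key position p0 must be masked for query position p1.
static bool is_masked_swa(swa_type type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (type) {
        case SWA_NONE:
            return false;
        case SWA_STANDARD:
            // mask keys that fall out of the trailing window of size n_swa
            return p1 - p0 >= n_swa;
        case SWA_CHUNKED: {
            // mask keys that precede the start of the chunk containing p1
            const int32_t pos_chunk_start = (p1 / n_swa) * n_swa;
            return p0 < pos_chunk_start;
        }
        case SWA_SYMMETRIC: {
            // mask keys outside [p1 - n_swa/2, p1 + n_swa/2]
            const int32_t half_n_swa = n_swa / 2;
            const int32_t pos_diff   = p1 - p0;
            return pos_diff < -half_n_swa || pos_diff > half_n_swa;
        }
    }
    return false;
}

int main() {
    // With n_swa = 4 and a query at position 5, the symmetric window keeps keys 3..7,
    // while the standard window keeps keys 2..5 once the causal check in set_mask
    // (pos_i > pos_j => masked) is also applied.
    for (int32_t p0 = 0; p0 < 10; ++p0) {
        printf("p0=%2d  symmetric masked: %d  standard masked: %d\n",
               p0,
               (int) is_masked_swa(SWA_SYMMETRIC, 4, p0, 5),
               (int) is_masked_swa(SWA_STANDARD,  4, p0, 5));
    }
    return 0;
}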