@@ -279,60 +279,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         // Check if we're using sliding window attention
-        if (n_swa > 0) {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-            const int64_t half_n_swa   = n_swa / 2;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            // Implement symmetric sliding window attention
-            // token i attends to tokens [i - n_swa/2, i + n_swa/2]
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-                        const int64_t pos_j = ubatch->pos[tj];
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        const int64_t pos_i = ubatch->pos[ti];
-                                        const int64_t pos_diff = pos_j - pos_i;
-
-                                        // Apply sliding window constraint
-                                        // [i - n_swa/2, i + n_swa/2]
-                                        if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
-                                            if (hparams.use_alibi) {
-                                                f = -std::abs(pos_diff);
-                                            } else {
-                                                f = 0.0f;
-                                            }
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else if (cparams.causal_attn) {
+        if (cparams.causal_attn) {
             const int64_t n_kv         = ubatch->n_tokens;
             const int64_t n_tokens     = ubatch->n_tokens;
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -375,6 +322,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
             const int64_t n_seqs       = ubatch->n_seqs;
             const int64_t n_stride     = ubatch->n_tokens;
+            const int64_t half_n_swa   = hparams.n_swa / 2;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
 
@@ -386,6 +334,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
                         const int32_t tj = s1*n_seq_tokens + j;
+                        const int64_t pos_j = ubatch->pos[tj];
 
                         for (int s0 = 0; s0 < n_seqs; ++s0) {
                             for (int i = 0; i < n_seq_tokens; ++i) {
@@ -394,7 +343,11 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                     if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
+                                        const int64_t pos_i = ubatch->pos[ti];
+                                        const int64_t pos_diff = pos_j - pos_i;
+
+                                        if (hparams.use_alibi &&
+                                            (pos_diff >= -half_n_swa && pos_diff <= half_n_swa)) {
                                             f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
                                         } else {
                                             f = 0.0f;
@@ -1242,22 +1195,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
-llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const {
-    // Use the sliding window size from hyperparameters
-    // If hparams.n_swa is 0, use a default value (128)
-    const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128;
-
-    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa);
-
-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    ggml_set_input(inp->kq_mask);
-
-    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
-
-    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
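
Illustrative note (not part of the patch): both the removed branch and the reworked condition encode the symmetric sliding-window rule from the deleted comment, i.e. a token at position pos_j may attend to positions within n_swa/2 on either side. The standalone sketch below shows just that masking rule in isolation; the helper name swa_mask_value and the sample values are hypothetical and assume only those window semantics.

// Hypothetical helper illustrating the symmetric sliding-window mask;
// not part of llama.cpp.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Additive attention mask for a query at pos_j attending to a key at pos_i:
// -INFINITY outside the window [pos_j - n_swa/2, pos_j + n_swa/2],
// otherwise an ALiBi-style distance bias or 0.0f for a plain mask.
static float swa_mask_value(int64_t pos_i, int64_t pos_j, int64_t n_swa, bool use_alibi) {
    const int64_t half_n_swa = n_swa / 2;
    const int64_t pos_diff   = pos_j - pos_i;

    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
        return -INFINITY;
    }
    return use_alibi ? -std::abs((float) pos_diff) : 0.0f;
}

int main() {
    // query at position 3 with n_swa = 4: keys at positions 1..5 stay visible
    for (int64_t pos_i = 0; pos_i < 8; ++pos_i) {
        printf("pos_i=%lld mask=%f\n", (long long) pos_i, swa_mask_value(pos_i, 3, 4, false));
    }
    return 0;
}

Unlike this sketch, the updated condition in set_input() applies the window test only together with hparams.use_alibi; the sketch is meant to illustrate the [i - n_swa/2, i + n_swa/2] semantics, not to mirror the patch line for line.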