@@ -363,111 +363,6 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_no_cache_iswa::set_input(const llama_ubatch * ubatch) {
-    // Standard attention mask
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
-    }
-
-    // SWA attention mask
-    if (kq_mask_swa) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t window_size  = hparams.n_swa;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask_swa->buffer));
-            float * data = (float *) kq_mask_swa->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        const bool in_window = (ubatch->pos[tj] - ubatch->pos[ti]) <= window_size;
-
-                                        if (in_window) {
-                                            if (hparams.use_alibi) {
-                                                f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                            } else {
-                                                f = 0.0f;
-                                            }
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
-    }
-}
-
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
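
For reference, the per-pair visibility rule that the removed kq_mask_swa branch encodes (causal attention restricted to a window of width hparams.n_swa, with an optional ALiBi distance bias) can be written in isolation. The sketch below is illustrative only, assuming a single sequence and no GGML_KQ_MASK_PAD padding; make_swa_mask and its parameters are hypothetical names standing in for the removed logic, not llama.cpp API.

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Hypothetical helper (not llama.cpp code): build an n x n causal sliding-window
// mask for one sequence; window_size plays the role of hparams.n_swa.
static std::vector<float> make_swa_mask(const std::vector<int32_t> & pos, int64_t window_size, bool use_alibi) {
    const size_t n = pos.size();
    std::vector<float> mask(n*n, -INFINITY); // start fully masked

    for (size_t tj = 0; tj < n; ++tj) {       // query position (row)
        for (size_t ti = 0; ti < n; ++ti) {   // key position (column)
            const bool causal    = pos[ti] <= pos[tj];
            const bool in_window = (pos[tj] - pos[ti]) <= window_size;
            if (causal && in_window) {
                // 0.0f leaves the attention logit unchanged; with ALiBi the bias
                // grows with the distance between the two positions
                mask[tj*n + ti] = use_alibi ? -std::abs(pos[ti] - pos[tj]) : 0.0f;
            }
        }
    }
    return mask;
}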