@@ -252,90 +252,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    // TODO: repace this if with GGML_ASSERT(kq_mask)
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = ubatch->n_tokens;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            const int64_t n_stride     = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                // TODO: fix indexing [UBATCH_IDX]
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv     = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
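The refactor above collapses the separate causal and non-causal branches into one loop over token pairs, with the causal position check folded into a single condition and the equal-length-sequence indexing (s1*n_seq_tokens + j) dropped in favor of direct token indices. To see that logic in isolation, here is a minimal standalone sketch; the helper name build_attn_mask_no_cache and its plain-array parameters are hypothetical stand-ins for llama_ubatch, and it simplifies to one sequence id per token:

// Standalone sketch (hypothetical helper, NOT llama.cpp API): fills an
// n_tokens x n_tokens KQ mask the same way the new set_input() does.
// mask[i1*n_tokens + i0] is 0.0f (or an ALiBi-style distance) when token
// i0 is visible to token i1, and -INFINITY otherwise.
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

static void build_attn_mask_no_cache(
        const std::vector<llama_pos>    & pos,     // position of each token
        const std::vector<llama_seq_id> & seq_id,  // sequence id of each token (one per token here)
        bool causal_attn,
        bool use_alibi,
        std::vector<float> & mask) {
    const int64_t n_tokens = (int64_t) pos.size();
    mask.assign(n_tokens*n_tokens, -INFINITY);

    for (int64_t i1 = 0; i1 < n_tokens; ++i1) {
        for (int64_t i0 = 0; i0 < n_tokens; ++i0) {
            // visible: same sequence, and not in the future when causal
            if (seq_id[i0] == seq_id[i1] && (!causal_attn || pos[i0] <= pos[i1])) {
                mask[i1*n_tokens + i0] = use_alibi ? -std::abs(pos[i0] - pos[i1]) : 0.0f;
            }
        }
    }
}

int main() {
    // two interleaved sequences: tokens 0,2 belong to seq 0, tokens 1,3 to seq 1
    const std::vector<llama_pos>    pos    = {0, 0, 1, 1};
    const std::vector<llama_seq_id> seq_id = {0, 1, 0, 1};

    std::vector<float> mask;
    build_attn_mask_no_cache(pos, seq_id, /*causal_attn=*/true, /*use_alibi=*/false, mask);
    // row for token 2 (seq 0, pos 1): attends tokens 0 and 2 only, e.g.
    // mask[2*4 + 0] == 0.0f and mask[2*4 + 1] == -INFINITY
    return 0;
}

Because visibility is now decided per (i1, i0) pair, the new code also no longer needs the old trailing loop that padded the row out to n_stride with -INFINITY: every pair is written exactly once, defaulting to masked.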
@@ -358,34 +304,36 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    if (cross_kq_mask) {
-        const int64_t n_enc    = cross_kq_mask->ne[0];
-        const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    const int64_t n_enc    = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;
 
-        float * data = (float *) cross_kq_mask->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_enc; ++i) {
-                    float f = -INFINITY;
-                    // TODO: fix indexing [UBATCH_IDX]
-                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
-                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }
 
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_enc; ++j) {
-                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-                }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
            }
         }
     }
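For cross attention the change is mostly mechanical (hoist the assert out of the removed if, rename the loop variables so i indexes decoder tokens and j indexes encoder outputs), but the masking rule itself is worth spelling out: decoder token i may attend encoder output j iff one of i's sequence ids is in cross->seq_ids_enc[j]. A minimal sketch under the same caveats as before (hypothetical helper, plain containers standing in for the real ubatch and cross state):

// Standalone sketch (hypothetical helper, NOT llama.cpp API): builds the
// cross-attention KQ mask. seq_ids_enc[j] is the set of sequence ids that
// encoder output j belongs to, mirroring cross->seq_ids_enc in the diff.
#include <cmath>
#include <cstdint>
#include <set>
#include <vector>

using llama_seq_id = int32_t;

static std::vector<float> build_cross_attn_mask(
        const std::vector<std::vector<llama_seq_id>> & seq_id,      // seq ids per decoder token
        const std::vector<std::set<llama_seq_id>>    & seq_ids_enc, // seq ids per encoder output
        int64_t n_tokens_padded) {                                  // >= n_tokens, e.g. GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)
    const int64_t n_tokens = (int64_t) seq_id.size();
    const int64_t n_enc    = (int64_t) seq_ids_enc.size();

    // all entries start masked; padded rows beyond n_tokens stay that way,
    // matching the trailing -INFINITY loop in the diff
    std::vector<float> mask(n_enc*n_tokens_padded, -INFINITY);

    for (int64_t i = 0; i < n_tokens; ++i) {
        for (int64_t j = 0; j < n_enc; ++j) {
            for (llama_seq_id s : seq_id[i]) {
                if (seq_ids_enc[j].count(s) > 0) {
                    mask[i*n_enc + j] = 0.0f; // decoder token i may attend encoder output j
                    break;
                }
            }
        }
    }
    return mask;
}

The rows in [n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) never correspond to real tokens; they exist only because the mask tensor is allocated with padded dimensions, so they are kept fully masked.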