@@ -48,6 +48,7 @@ llama_context::llama_context(
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
@@ -2127,60 +2128,44 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
     }
 
     if (inp_kq_mask) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
-            // TODO: need to use the batch directly to construct the masks
-            GGML_ABORT("TODO");
-
-            // const int64_t n_kv         = ubatch.n_tokens;
-            // const int64_t n_tokens     = ubatch.n_tokens;
-            // const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            // const int64_t n_seqs       = ubatch.n_seqs;
-
-            // float * data = nullptr;
-
-            // if (inp_kq_mask) {
-            //     GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
-            //     data = (float *) inp_kq_mask->data;
-            // }
-
-            // // For causal attention, use only the previous KV cells
-            // // of the correct sequence for each token of the ubatch.
-            // // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            // for (int h = 0; h < 1; ++h) {
-            //     for (int s = 0; s < n_seqs; ++s) {
-            //         const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            //         for (int j = 0; j < n_seq_tokens; ++j) {
-            //             const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
-            //             for (int i = 0; i < n_kv; ++i) {
-            //                 float f;
-            //                 if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-            //                     f = -INFINITY;
-            //                 } else {
-            //                     if (hparams.use_alibi) {
-            //                         f = -std::abs(kv_self.cells[i].pos - pos);
-            //                     } else {
-            //                         f = 0.0f;
-            //                     }
-            //                 }
-
-            //                 if (data) {
-            //                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-            //                 }
-            //             }
-            //         }
-            //     }
-
-            //     if (data) {
-            //         for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            //             for (int j = 0; j < n_kv; ++j) {
-            //                 data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            //             }
-            //         }
-            //     }
-            // }
+            const int64_t n_kv         = ubatch.n_tokens;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
+            float * data = (float *) inp_kq_mask->data;
+
+            // cache-less context: build the causal mask directly from the ubatch,
+            // allowing token tj to attend to token ti only if they share a sequence and ti is not after tj
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
+                                    if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
+                    }
+                }
+            }
         } else {
             const int64_t n_tokens     = ubatch.n_tokens;
             const int64_t n_seq_tokens = ubatch.n_seq_tokens;
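For reference, below is a minimal standalone sketch of the masking rule introduced by this hunk, not the actual llama.cpp implementation: the flat `seq_id`/`pos` arrays and the toy batch shape are assumed purely for illustration. A query token `tj` may attend to a key token `ti` only if they share a sequence id and `pos[ti] <= pos[tj]`; with ALiBi the stored value is `-|pos[ti] - pos[tj]|` instead of `0`, and everything else is `-INFINITY`.

```cpp
// Standalone sketch of the cache-less causal KQ mask rule (simplified types, toy data;
// not the llama_ubatch layout used in the PR).
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int  n_seqs       = 2;
    const int  n_seq_tokens = 3;
    const int  n_tokens     = n_seqs*n_seq_tokens;
    const int  n_kv         = n_tokens; // cache-less: the keys are the batch tokens themselves
    const bool use_alibi    = false;

    // toy ubatch: one sequence id and one position per token
    const std::vector<int> seq_id = { 0, 0, 0, 1, 1, 1 };
    const std::vector<int> pos    = { 0, 1, 2, 0, 1, 2 };

    std::vector<float> mask(n_kv*n_tokens, -INFINITY);

    for (int tj = 0; tj < n_tokens; ++tj) {     // query token
        for (int ti = 0; ti < n_kv; ++ti) {     // key token
            if (seq_id[ti] == seq_id[tj] && pos[ti] <= pos[tj]) {
                mask[tj*n_kv + ti] = use_alibi ? -std::abs(pos[ti] - pos[tj]) : 0.0f;
            }
        }
    }

    // print the mask: rows are query tokens, columns are key tokens, -inf means "masked out"
    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_kv; ++ti) {
            std::printf("%7.1f", mask[tj*n_kv + ti]);
        }
        std::printf("\n");
    }

    return 0;
}
```

The sketch only reproduces the per-token logic; the real code additionally iterates over `ubatch.n_seq_id[s0]` because a token may belong to several sequences, and indexes `data` with the `h`/`n_kv`/`n_tokens` strides expected by `ggml_flash_attn_ext`.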