@@ -48,6 +48,7 @@ llama_context::llama_context(
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
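
For context: the padding referred to above rounds the mask's token dimension up to a multiple of GGML_KQ_MASK_PAD, with the extra rows filled with -INFINITY (the removed code in the next hunk does exactly this). Below is a minimal standalone sketch of that shape calculation; the pad value and macro names are illustrative stand-ins, the real constants live in ggml.h.

// sketch only: shape of the KQ mask with padding (pad value is an assumed example)
#include <cstdio>

#define KQ_MASK_PAD 32                                // illustrative value, not ggml.h's
#define PAD(x, n)   (((x) + (n) - 1) & ~((n) - 1))    // same rounding idea as ggml's GGML_PAD

int main() {
    int n_batch = 17;                                 // tokens in the current ubatch
    if (n_batch < KQ_MASK_PAD) {
        n_batch = KQ_MASK_PAD;                        // the clamp added in the hunk above
    }
    const int n_kv   = n_batch;                       // cache-less: keys are the batch tokens
    const int n_rows = PAD(n_batch, KQ_MASK_PAD);     // query rows, rounded up
    // rows in [n_batch, n_rows) are filled with -INFINITY so padded queries attend to
    // nothing (see the removed loop over GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) below)
    printf("KQ mask: %d keys x %d query rows\n", n_kv, n_rows);
    return 0;
}
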
@@ -2127,60 +2128,44 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
     }
 
     if (inp_kq_mask) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
-            // TODO: need to use the batch directly to construct the masks
-            GGML_ABORT("TODO");
-
-            // const int64_t n_kv         = ubatch.n_tokens;
-            // const int64_t n_tokens     = ubatch.n_tokens;
-            // const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            // const int64_t n_seqs       = ubatch.n_seqs;
-
-            // float * data = nullptr;
-
-            // if (inp_kq_mask) {
-            //     GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
-            //     data = (float *) inp_kq_mask->data;
-            // }
-
-            // // For causal attention, use only the previous KV cells
-            // // of the correct sequence for each token of the ubatch.
-            // // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            // for (int h = 0; h < 1; ++h) {
-            //     for (int s = 0; s < n_seqs; ++s) {
-            //         const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-            //         for (int j = 0; j < n_seq_tokens; ++j) {
-            //             const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
-            //             for (int i = 0; i < n_kv; ++i) {
-            //                 float f;
-            //                 if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-            //                     f = -INFINITY;
-            //                 } else {
-            //                     if (hparams.use_alibi) {
-            //                         f = -std::abs(kv_self.cells[i].pos - pos);
-            //                     } else {
-            //                         f = 0.0f;
-            //                     }
-            //                 }
-
-            //                 if (data) {
-            //                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-            //                 }
-            //             }
-            //         }
-            //     }
-
-            //     if (data) {
-            //         for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            //             for (int j = 0; j < n_kv; ++j) {
-            //                 data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-            //             }
-            //         }
-            //     }
-            // }
+            const int64_t n_kv         = ubatch.n_tokens;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
+            float * data = (float *) inp_kq_mask->data;
+
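+            // cache-less context: build the causal mask directly from the ubatch.
+            // query token tj may attend to key token ti only if ti carries tj's
+            // sequence id and pos[ti] <= pos[tj]; with ALiBi the unmasked entries
+            // hold the negative position distance instead of 0.0f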
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
+                                    if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
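+                                // data is laid out as [h][tj][ti]: n_tokens rows of n_kv entries each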
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
+                    }
+                }
+            }
         } else {
             const int64_t n_tokens     = ubatch.n_tokens;
             const int64_t n_seq_tokens = ubatch.n_seq_tokens;
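
For reference, a self-contained sketch of the pattern the new mask-building loop produces, using simplified stand-in types (plain ints for llama_pos/llama_seq_id, one sequence id per token, a hand-built 2-sequence ubatch) rather than the real llama.cpp structures; ALiBi is left off, so unmasked entries are 0.0f.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// stand-ins for the real llama.cpp types, for illustration only
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

int main() {
    // toy ubatch: 2 sequences, 2 tokens per sequence, laid out sequence-major
    const int n_seqs       = 2;
    const int n_seq_tokens = 2;
    const int n_tokens     = n_seqs*n_seq_tokens;
    const int n_kv         = n_tokens;               // cache-less: keys == batch tokens

    const std::vector<llama_pos>    pos    = { 0, 1, 0, 1 };  // position of each token
    const std::vector<llama_seq_id> seq_id = { 0, 0, 1, 1 };  // one sequence id per token

    std::vector<float> data(n_kv*n_tokens, 0.0f);

    // same structure as the loop in the diff, assuming n_seq_id[s0] == 1 for every token
    for (int s1 = 0; s1 < n_seqs; ++s1) {
        for (int j = 0; j < n_seq_tokens; ++j) {
            const int tj = s1*n_seq_tokens + j;
            for (int s0 = 0; s0 < n_seqs; ++s0) {
                for (int i = 0; i < n_seq_tokens; ++i) {
                    const int ti = s0*n_seq_tokens + i;
                    const bool allowed = seq_id[ti] == seq_id[tj] && pos[ti] <= pos[tj];
                    data[tj*n_kv + ti] = allowed ? 0.0f : -INFINITY;
                }
            }
        }
    }

    // prints a block-diagonal, lower-triangular pattern: each token attends only to
    // earlier (or same-position) tokens of its own sequence
    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_kv; ++ti) {
            printf("%5s ", data[tj*n_kv + ti] == 0.0f ? "0" : "-inf");
        }
        printf("\n");
    }
    return 0;
}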