Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions src/llama-kv-cache-unified.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,8 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
const int64_t n_tps = n_tokens/n_stream;
const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);

std::fill(data, data + ggml_nelements(dst), -INFINITY);

// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
Expand All @@ -1306,44 +1308,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub

const llama_pos p1 = ubatch->pos[i];

for (uint32_t j = 0; j < n_kv; ++j) {
float f = 0.0f;

bool masked = false;
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

for (uint32_t j = 0; j < n_kv; ++j) {
if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);

// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));
continue;
}

// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// mask the token if not the same sequence
if (!cells.seq_has(j, seq_id)) {
continue;
}

// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
const llama_pos p0 = cells.pos_get(j);

if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
// mask future tokens
if (causal_attn && p0 > p1) {
continue;
}

if (masked) {
f = -INFINITY;
// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
}

data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f;
}

// mask padded tokens
if (data) {
for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY;
}
}
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}
}
}
Expand Down
Loading