Commit d8a1c8a

improve llm_graph_input_attn_no_cache::set_input:if(kq_mask_swa) to handle causal_attn
1 parent 72b56c4 commit d8a1c8a

src/llama-graph.cpp

Lines changed: 9 additions & 1 deletion
@@ -383,8 +383,16 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 const int64_t pos_i = ubatch->pos[ti];
                 const int64_t pos_diff = pos_j - pos_i;
 
+                // Check both causal attention and symmetric sliding window
+                bool masked = false;
+
+                // Apply causal attention if enabled (only allow attention to past tokens)
+                if (cparams.causal_attn && pos_i > pos_j) {
+                    masked = true;
+                }
+
                 // Apply symmetric sliding window attention logic
-                if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
+                if (!masked && pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
                     if (hparams.use_alibi) {
                         f = -std::abs(pos_i - pos_j);
                     } else {
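
Read on its own, the combined rule is a pure function of the query and key positions. Below is a minimal standalone C++ sketch of that rule; the helper name swa_causal_mask_bias and its signature are illustrative only (not llama.cpp API), and it assumes the conventional mask values implied by the branches the hunk does not show: 0.0f for a visible position without ALiBi, and -INFINITY for a masked one.

    #include <cmath>
    #include <cstdint>
    #include <limits>

    // Additive attention bias for a query at pos_j attending to a key at pos_i.
    // Hypothetical helper mirroring the patched loop body above.
    static float swa_causal_mask_bias(int64_t pos_j, int64_t pos_i,
                                      int64_t half_n_swa,
                                      bool causal_attn, bool use_alibi) {
        const int64_t pos_diff = pos_j - pos_i;

        // Causal attention, when enabled, hides keys that lie in the future
        // relative to the query (pos_i > pos_j).
        const bool masked = causal_attn && pos_i > pos_j;

        // Symmetric sliding window: only keys within +/- half_n_swa of the
        // query remain visible.
        if (!masked && pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
            // ALiBi penalizes visible keys by their negative absolute distance;
            // otherwise a visible key gets a neutral bias (assumed 0.0f).
            return use_alibi ? -std::abs((float) (pos_i - pos_j)) : 0.0f;
        }

        // Masked positions are assumed to receive -INFINITY, the usual
        // convention for additive attention masks.
        return -std::numeric_limits<float>::infinity();
    }

Because the causal check runs first and the window test is gated on !masked, causal masking takes precedence: with causal_attn enabled, the effective window for a query at pos_j shrinks from [pos_j - half_n_swa, pos_j + half_n_swa] to [pos_j - half_n_swa, pos_j].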
