File tree Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Expand file tree Collapse file tree 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -383,8 +383,16 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
383383 const int64_t pos_i = ubatch->pos [ti];
384384 const int64_t pos_diff = pos_j - pos_i;
385385
386+ // Check both causal attention and symmetric sliding window
387+ bool masked = false ;
388+
389+ // Apply causal attention if enabled (only allow attention to past tokens)
390+ if (cparams.causal_attn && pos_i > pos_j) {
391+ masked = true ;
392+ }
393+
386394 // Apply symmetric sliding window attention logic
387- if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
395+ if (!masked && pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
388396 if (hparams.use_alibi ) {
389397 f = -std::abs (pos_i - pos_j);
390398 } else {
You can’t perform that action at this time.
0 commit comments