
Attention & Sliding Window

Yeongjae Jang edited this page Feb 2, 2026 · 7 revisions

Attention Mechanism

This page covers:

  • TVM's attention calculation structure, its shortcomings, and how to improve it
  • Techniques for applying a pre-trained attention sink

Computing attention in TVM

SelfAttention & CrossAttention

There are two main methods for computing attention in TVM; depending on the situation, either one or a combination of both is used.

  • SelfAttention: Attention is computed where Q, K, and V are all derived from the same input sequence.
  • CrossAttention: Attention is computed using Q from the current input sequence, and K and V from the KV cache (i.e., a previously processed sequence).

Note

CrossAttention here can differ from the widely used 'cross attention' in that the sequence Q originates from and the sequence K and V originate from ultimately form a single sequence.

SelfAttention is computed over a lower-triangular region, with a causal mask applied to preserve temporal causality. This means that, for any given token, the model can only attend to tokens that appear earlier in the sequence.
In contrast, CrossAttention is computed over a rectangular region, as it attends exclusively to information from a previously processed sequence.
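The two mask shapes can be illustrated with a minimal NumPy sketch (illustrative only, not TVM code):

```python
import numpy as np

def self_attn_mask(qo_len):
    # SelfAttention: lower-triangular causal mask over the current chunk.
    # Token i may attend only to tokens 0..i.
    return np.tril(np.ones((qo_len, qo_len), dtype=bool))

def cross_attn_mask(qo_len, kv_len):
    # CrossAttention: rectangular mask; every new query may attend to
    # every entry already in the KV cache.
    return np.ones((qo_len, kv_len), dtype=bool)
```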

These two attention mechanisms are applied depending on the LLM execution scenario:

  • Initial Prefill: Only SelfAttention is computed for the prefill sequence.

  • Each Decode Step: CrossAttention is applied for each newly generated token, attending to the previously processed sequence.

  • Subsequent Prefill (Extend): SelfAttention is computed for the new prefill sequence, and CrossAttention is computed against the previous sequence. The two results are then merged using their LogSumExp (LSE) values. Alternatively, both computations may be performed in-place into a single output buffer.
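The LSE-based merge used in the extend scenario can be sketched in NumPy as follows. This is a simplified single-head version; `partial_attention` and `merge_attention_states` are illustrative names, not TVM APIs:

```python
import numpy as np

def partial_attention(q, k, v):
    """Attention over a slice of keys; returns the output and per-row LSE."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)       # row-wise max for stability
    p = np.exp(scores - m)
    s = p.sum(axis=-1, keepdims=True)
    o = (p / s) @ v                              # softmax over this slice only
    lse = (m + np.log(s)).squeeze(-1)            # log of the softmax denominator
    return o, lse

def merge_attention_states(o1, lse1, o2, lse2):
    """Combine two partial outputs as if softmax ran over both key ranges."""
    lse_max = np.maximum(lse1, lse2)             # shift for numerical stability
    w1 = np.exp(lse1 - lse_max)                  # proportional to each slice's
    w2 = np.exp(lse2 - lse_max)                  # softmax denominator
    return (o1 * w1[:, None] + o2 * w2[:, None]) / (w1 + w2)[:, None]
```

Because each partial output is weighted by its own softmax denominator, the merged result is identical to running attention over the concatenated key range.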

Sliding window attention (SWA) in TVM

When running an LLM with TVM & MLC LLM, you can set sliding_window_size in the model config to make the model use SWA.

You can also apply per-layer SWA by providing a sequence of AttnKind values ("mha" or "mha_sliding") instead of a single value when initializing the KV cache.
(You can find an example of this in the Gemma 3 model.)

class Gemma3LanguageModel(nn.Module):
    ...

    def create_paged_kv_cache(...):
        return PagedKVCache.create_generic(
            attn_kind=[
                (
                    "mha_sliding"
                    if ((i + 1) % self.config.text_config.sliding_window_pattern)
                    else "mha"
                )
                for i in range(self.num_hidden_layers)
            ],
            ...
        )
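For instance, assuming sliding_window_pattern = 6 (Gemma 3 interleaves five sliding-window layers with one global-attention layer) and a 12-layer model for brevity, the list comprehension above yields:

```python
# Assumption: sliding_window_pattern = 6, as in Gemma 3's text config;
# num_hidden_layers is shortened to 12 for illustration.
sliding_window_pattern = 6
num_hidden_layers = 12

attn_kind = [
    "mha_sliding" if ((i + 1) % sliding_window_pattern) else "mha"
    for i in range(num_hidden_layers)
]
# Layers 5 and 11 (0-indexed) use full attention ("mha"); all others use SWA.
```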

Limitations of TVM's Attention computation

Considering sliding window size in SelfAttention

In the current TVM implementation, sliding window attention is supported for CrossAttention, but it is not applied to SelfAttention.

Naively applying a sliding window can lead to incorrect behavior, as previously presented context tokens may be unintentionally excluded. For this reason, models typically adopted sliding window sizes that were relatively large, often comparable to the full context window size. In addition, prefill inputs (e.g., user prompts or tool call outputs) were generally not expected to be long enough to exceed the sliding window boundary. As a result, we suspect that the absence of sliding window handling in SelfAttention was rarely an issue in practice.

However, the approach of mixing sliding window layers and dense layers (as in gpt-oss) significantly alleviates this problem and allows much smaller sliding window sizes. At the same time, recent agent-centric LLM use cases can easily produce long prefill sequences that exceed the sliding window size. Under these conditions, applying sliding window constraints to SelfAttention becomes necessary to ensure correct attention behavior.

Hard-coded values in Per-layer SWA

In addition, the per-layer sliding window logic in TVM relies on hard-coded values tailored for Gemma 3, which was the primary target model at the time.

These model-specific constants limit portability to other models with different configurations.

Our improvement

Addressing these limitations required modifications to TVM itself.

To make the changes explicit and easy to track relative to upstream TVM, we applied them as a standalone patch file on top of the original TVM codebase.

We also proposed these changes to upstream TVM via a pull request.

Applying sliding window in SelfAttention

For CrossAttention, sliding window behavior is implemented by adjusting the range of entries referenced from the KV cache, and this mechanism is already supported in TVM.
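The KV-cache-range view of SWA for CrossAttention can be sketched as follows (`kv_range` is an illustrative helper, not a TVM function):

```python
def kv_range(num_cached, sliding_window_size=-1):
    """Return the [start, end) range of KV cache entries to attend to.

    With SWA enabled, CrossAttention only needs the last
    `sliding_window_size` cached entries; a non-positive size means
    no restriction (dense attention over the whole cache).
    """
    start = 0
    if sliding_window_size > 0:
        start = max(0, num_cached - sliding_window_size)
    return start, num_cached
```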

In contrast, applying a sliding window to SelfAttention requires extending the causal mask logic. Previously, the causal mask enforced only an upper bound to preserve temporal causality. To support sliding window constraints, a lower bound must also be introduced to exclude tokens that fall outside the sliding window range.

Original implementation

def _causal_mask(causal, row, col, kv_len, qo_len):
    return T.if_then_else(
        causal > 0,
        col < kv_len - qo_len + row + 1,
        col < kv_len,
    )

Modified implementation

def _causal_mask(causal, row, col, kv_len, qo_len, sliding_window_size=-1):
    # Lower bound: exclude keys that fall outside the sliding window.
    # A non-positive sliding_window_size disables this bound.
    lower_bound_condition = T.if_then_else(
        sliding_window_size > 0,
        col >= kv_len - qo_len + row - sliding_window_size + 1,
        True,
    )
    return T.if_then_else(
        causal > 0,
        # Upper bound (causality) combined with the new lower bound.
        tir.all(col < kv_len - qo_len + row + 1, lower_bound_condition),
        col < kv_len,
    )

Beyond this change, the remaining modifications simply pass an appropriate sliding_window_size value to the _causal_mask() function.
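To see the effect of the extra bound, the causal branch of the predicate can be mirrored in NumPy (an illustrative check, not part of the patch):

```python
import numpy as np

def sw_causal_mask(kv_len, qo_len, sliding_window_size=-1):
    # NumPy mirror of the causal branch of the modified _causal_mask().
    row = np.arange(qo_len)[:, None]   # query rows of the current chunk
    col = np.arange(kv_len)[None, :]   # key/value columns
    upper = col < kv_len - qo_len + row + 1   # temporal causality
    if sliding_window_size > 0:
        # New lower bound: drop keys that fall outside the sliding window.
        lower = col >= kv_len - qo_len + row - sliding_window_size + 1
        return upper & lower
    return upper
```

With kv_len = qo_len = 5 and a window of 2, query row 3 attends only to columns 2 and 3, whereas without a window it attends to columns 0 through 3.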

Fix to use configurable values

Several hard-coded constants were identified in the per-layer sliding window logic:

  • sliding window size: 1024
  • rotary theta: 10000
  • rotary scale: 1

These values were originally fixed for Gemma 3. We modified the implementation to make them configurable, allowing the same logic to be reused for other target models with different architectural parameters.
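A minimal sketch of the idea, with hypothetical names (`AttnConfig` and its fields are illustrative, not TVM's actual API; the non-Gemma values below are made up):

```python
from dataclasses import dataclass

@dataclass
class AttnConfig:
    # Formerly hard-coded constants (1024 / 10000 / 1), now read from the
    # model configuration instead.
    sliding_window_size: int = -1     # -1 disables the sliding window
    rope_theta: float = 10000.0
    rope_scale: float = 1.0

# A Gemma 3-like configuration reproduces the old hard-coded behavior...
gemma3_like = AttnConfig(sliding_window_size=1024)
# ...while other models can supply their own values (hypothetical example).
other_model = AttnConfig(sliding_window_size=128, rope_theta=1e6)
```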

The content of this section has been moved to a separate document.
