

Yeongjae Jang edited this page Feb 3, 2026 · 7 revisions

Sink Token Workaround (LSE + Sigmoid)

Background / Problem

What is Attention Sink?

Attention sink is a phenomenon in autoregressive LLMs where attention is biased toward the initial tokens of the sequence, causing the model to give them more importance than later, more relevant tokens.

This occurs due to the softmax function in attention computation, which requires all attention scores to sum to 1. When many tokens are not strongly relevant to the current context, the model still needs to distribute attention somewhere. As observed in this paper, the initial tokens, being globally visible from any position in the sequence, serve as a "sink" to absorb this excess attention, effectively increasing the denominator in the softmax calculation.

The paper also demonstrates that explicitly adding a dedicated sink token during training leads to more stable performance. Instead of relying on the naturally occurring initial tokens (like BOS or the first word) to absorb excess attention, the model can be trained with a learnable placeholder token specifically designed to serve as an attention sink. This approach stabilizes the attention distribution and improves streaming inference quality, as the model learns to consistently dump irrelevant attention to this dedicated token rather than repurposing semantically meaningful initial tokens.

Attention Sink in gpt-oss vs TVM/MLC-LLM

gpt-oss implements attention sink with a learnable sink parameter S:

QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
QK = torch.cat([QK, S], dim=-1)  # Append learnable sink S to attention logits
W = torch.softmax(QK, dim=-1)     # S participates in softmax denominator
W = W[..., :-1]                   # Remove S from weights before matmul with V
attn = torch.einsum("hmqk,khmd->qhmd", W, V)
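To see the effect of the virtual sink, here is a toy example in plain Python (no torch; the logit values are arbitrary): because the sink logit enters the softmax denominator but its weight is discarded, the retained attention weights sum to less than 1.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

# Toy attention logits plus a large sink logit s: most mass goes to the sink.
logits = [0.1, 0.2, 0.0]
s = 3.0

weights = softmax(logits + [s])[:-1]  # drop the sink weight, as gpt-oss does
total = sum(weights)                  # strictly less than 1: the sink absorbed the rest
```

With these numbers the retained mass is roughly 0.14, so the sink has soaked up most of the attention that would otherwise be forced onto weakly relevant tokens.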

TVM/MLC-LLM implements attention sink by preserving the first N tokens in KV cache:

int32_t offset = i < block.sink_length 
    ? i 
    : i - block.sink_length + block.sliding_window_offset;
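The index mapping above can be sketched in Python as follows (a minimal illustration, with hypothetical example values `sink_length=4` and `sliding_window_offset=16`; the real TVM code operates on paged KV cache blocks):

```python
def kv_offset(i, sink_length, sliding_window_offset):
    # The first `sink_length` tokens keep their physical slots (the sink anchors);
    # all later tokens are remapped into the sliding-window region of the cache.
    if i < sink_length:
        return i
    return i - sink_length + sliding_window_offset

# Example: 4 preserved sink tokens, sliding window starting at physical slot 16.
offsets = [kv_offset(i, 4, 16) for i in range(8)]
# Sink tokens 0-3 stay put; tokens 4-7 land at slots 16-19.
```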

The key difference:

  • gpt-oss: Virtual sink - a learnable parameter concatenated to attention logits to increase the softmax denominator, then discarded before value projection
  • TVM/MLC-LLM: Positional sink - actual tokens stored in KV cache as fixed anchors

This virtual sink approach cannot be implemented directly with TVM's current paged KV cache structure, which only supports preserving physical token positions.

Our Approach

The sink logic in gpt-oss can be implemented by using the sigmoid function and LogSumExp.

Key Point: use LogSumExp

Without loss of generality, we focus on a single element of S and denote it as a scalar $s$.
For a query $Q$, a key $K$, a value $V$, we define

| Notation | Definition |
| --- | --- |
| $[{Q^T}{K}, s]$ | ${Q^T}{K}$ with $s$ appended |
| $A(Q, K, V) = \text{softmax}({Q^T}{K}) \cdot {V}$ | Attention output for $Q$, $K$, and $V$ |
| $A(Q, K, V; s) = \text{softmax}([{Q^T}{K}, s]) \cdot {V}$ | Attention output with the sink parameter |

From the code of gpt-oss, we obtain that

$$ \begin{aligned} A(Q, K, V; s) &:= \text{softmax}([{Q^T}{K}, s]) \cdot {V} \\ &=\frac{1}{\displaystyle\sum_{i}{\exp(a_i)} + \exp(s)} ({\exp(\mathbf{a})} \cdot {V} ) \\ &=\frac{\displaystyle\sum_{i}{\exp(a_i)}}{\displaystyle\sum_{i}{\exp(a_i)} + \exp(s)} \left( \frac{1}{\displaystyle\sum_{i}{\exp(a_i)}} ({\exp(\mathbf{a})} \cdot {V} ) \right) \\ &=\frac{\displaystyle\sum_{i}{\exp(a_i)}}{\displaystyle\sum_{i}{\exp(a_i)} + \exp(s)} \left( \text{softmax}({Q^T}{K}) \cdot {V} \right) \\ &=\frac{\displaystyle\sum_{i}{\exp(a_i)}}{\displaystyle\sum_{i}{\exp(a_i)} + \exp(s)} \left( A(Q, K, V) \right) \end{aligned} $$

where $\mathbf{a} = {Q^T}{K}$. This relates the attention functions $A(Q, K, V; s)$ and $A(Q, K, V)$ (scaling factors omitted for brevity).

Furthermore, we can rewrite the scale in a form that is convenient for implementation:

$$ \begin{aligned} \frac{\displaystyle\sum_{i}{\exp(a_i)}}{\displaystyle\sum_{i}{\exp(a_i)} + \exp(s)} &= \frac{1}{ 1 + \frac{\exp(s)}{\displaystyle\sum_{i}{\exp(a_i)}} } \\ &= \frac{1}{ 1 + \exp\left( s - \log\left({\displaystyle\sum_{i}{\exp(a_i)}}\right) \right) } \\ &= \text{sigmoid}\left( \log\left({\displaystyle\sum_{i}{\exp(a_i)}}\right) - s \right) \end{aligned} $$

where $\log$ is the natural logarithm. The scale is therefore the sigmoid of $\text{LSE}(\mathbf{a}) - s$, with $\text{LSE}(\mathbf{a}) = \log\sum_{i}\exp(a_i)$, so the sink correction can be computed entirely from the sigmoid and LogSumExp functions.
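As a sanity check, the identity can be verified numerically with a toy example (plain Python, one query, one head, arbitrary numbers): applying softmax over $[\mathbf{a}, s]$ and dropping the sink weight gives the same result as scaling ordinary attention by $\text{sigmoid}(\text{LSE}(\mathbf{a}) - s)$.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

# Toy scores a = Q^T K and values V (one head, one query), plus a scalar sink s.
a = [0.3, -1.2, 0.8]
V = [2.0, -0.5, 1.0]
s = 0.5

# Left-hand side: softmax over [a, s], drop the sink weight, combine with V.
w = softmax(a + [s])[:-1]
lhs = sum(wi * vi for wi, vi in zip(w, V))

# Right-hand side: sigmoid(LSE(a) - s) * (softmax(a) . V)
lse = math.log(sum(math.exp(x) for x in a))
scale = 1.0 / (1.0 + math.exp(-(lse - s)))
rhs = scale * sum(wi * vi for wi, vi in zip(softmax(a), V))
```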

Implementation

We leverage the LSE (LogSumExp) returned by the attention kernel to implement the virtual sink. The kernel returns LSE in log base 2, so we multiply it by $\log(2)$ (the natural log of 2) to convert it to the natural-log scale that the derivation above assumes.

        match forward_to:
            case "prefill":
                attention, lse_qk = paged_kv_cache.self_attention(self.layer_idx, q=q, k=k, v=v, sm_scale=self.sm_scale)
            case "decode":
                attention, lse_qk = paged_kv_cache.cross_attention(self.layer_idx, q=q, v_head_dim=d, sm_scale=self.sm_scale)
            case "extend":
                # For continuing prefill: attend to both current tokens (self) and cached tokens (cross)
                o_self, lse_self = paged_kv_cache.self_attention(self.layer_idx, q=q, k=k, v=v, sm_scale=self.sm_scale)
                o_cross, lse_cross = paged_kv_cache.cross_attention(self.layer_idx, q=q, v_head_dim=d, sm_scale=self.sm_scale)
                attention, lse_qk = paged_kv_cache.merge_attn_output_inplace(o_self, lse_self, o_cross, lse_cross)
            case _:
                raise ValueError(f"forward_to {forward_to} not supported")

        # TVM computes LSE in log base 2 internally,
        # so multiply by log(2) (the natural log of 2) to convert to natural log
        lse_qk = log(2) * lse_qk
        sinks_f32 = nn.op.astype(self.sinks, "float32")
        sigmoid_scale = nn.op.sigmoid(lse_qk - sinks_f32)

        sigmoid_scale = nn.op.unsqueeze(sigmoid_scale, dim=-1)
        sigmoid_scale = nn.op.astype(sigmoid_scale, self.dtype)

        attention = nn.op.multiply(attention, sigmoid_scale)
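The base-2 correction can be checked in isolation (a minimal sketch in plain Python; `scores` is a stand-in for one row of attention logits): a kernel that accumulates in base 2 returns $\log_2\sum_i \exp(a_i)$, and multiplying by $\log(2)$ recovers the natural-log LSE used in the sigmoid scale.

```python
import math

scores = [0.3, -1.2, 0.8]

# Natural-log LSE, as the derivation assumes.
lse_natural = math.log(sum(math.exp(x) for x in scores))

# A base-2 kernel would return log2 of the same sum.
lse_base2 = math.log2(sum(math.exp(x) for x in scores))

# Converting back: log2(x) * ln(2) == ln(x).
recovered = lse_base2 * math.log(2)
```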
