# Sink Token Workaround
Attention sink is a phenomenon in autoregressive LLMs where attention is biased toward the initial tokens of the sequence, causing the model to give them more importance than later, more relevant tokens.
This occurs due to the softmax function in attention computation, which requires all attention scores to sum to 1. When many tokens are not strongly relevant to the current context, the model still needs to distribute attention somewhere. As observed in this paper, the initial tokens, being globally visible from any position in the sequence, serve as a "sink" to absorb this excess attention, effectively increasing the denominator in the softmax calculation.
The paper also demonstrates that explicitly adding a dedicated sink token during training leads to more stable performance. Instead of relying on the naturally occurring initial tokens (like BOS or the first word) to absorb excess attention, the model can be trained with a learnable placeholder token specifically designed to serve as an attention sink. This approach stabilizes the attention distribution and improves streaming inference quality, as the model learns to consistently dump irrelevant attention to this dedicated token rather than repurposing semantically meaningful initial tokens.
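To make the denominator effect concrete, here is a minimal NumPy sketch (all values are illustrative) showing that softmax always distributes the full attention mass, even when no key is strongly relevant:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax: weights are positive and sum to 1
    e = np.exp(z - z.max())
    return e / e.sum()

# No key is strongly relevant: all logits are small and similar
logits = np.array([0.1, 0.0, -0.1, 0.05])
w = softmax(logits)

# The weights still sum to 1, so roughly a quarter of the attention mass
# lands on each token whether or not it is useful
assert np.isclose(w.sum(), 1.0)
assert np.all(w > 0.2)
```

This is the excess attention that a sink token is trained to absorb.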
gpt-oss implements attention sink with a learnable sink parameter S:
```python
QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
QK = torch.cat([QK, S], dim=-1)  # Append learnable sink S to attention logits
W = torch.softmax(QK, dim=-1)    # S participates in the softmax denominator
W = W[..., :-1]                  # Remove S from the weights before the matmul with V
attn = torch.einsum("hmqk,khmd->qhmd", W, V)
```

TVM/MLC-LLM implements attention sink by preserving the first N tokens in the KV cache:
```cpp
int32_t offset = i < block.sink_length
                     ? i
                     : i - block.sink_length + block.sliding_window_offset;
```

The key difference:
- gpt-oss: Virtual sink - a learnable parameter concatenated to attention logits to increase the softmax denominator, then discarded before value projection
- TVM/MLC-LLM: Positional sink - actual tokens stored in KV cache as fixed anchors
This virtual sink approach cannot be implemented directly with TVM's current paged KV cache structure, which only supports preserving physical token positions.
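For illustration, the positional-sink index mapping from the C++ snippet above can be sketched in Python (the function name and example values are hypothetical; only the arithmetic follows the snippet):

```python
def kv_cache_offset(i: int, sink_length: int, sliding_window_offset: int) -> int:
    """Map logical token position i to its KV-cache slot: the first
    `sink_length` tokens stay pinned in place, while later tokens are
    relocated into the sliding-window region."""
    if i < sink_length:
        return i  # sink tokens keep their original slots
    return i - sink_length + sliding_window_offset

# With 2 sink tokens and a sliding-window region starting at slot 10:
offsets = [kv_cache_offset(i, 2, 10) for i in range(5)]
assert offsets == [0, 1, 10, 11, 12]  # positions 0-1 pinned, the rest shifted
```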
The sink logic of gpt-oss can be implemented using the sigmoid function and LogSumExp.

Without loss of generality, we focus on a single element of S and denote it as a scalar $s$. For a query $q$ with keys $k_1, \dots, k_n$ and values $v_1, \dots, v_n$, let $z_j = q^\top k_j$ denote the attention logits.

| Denoted by | Definition |
|---|---|
| $z'$ | attention logits with $s$ concatenated: $z' = [z_1, \dots, z_n, s]$ |
| $\mathrm{LSE}$ | $\log \sum_{j=1}^{n} \exp(z_j)$ |
| $O$ | attention value for Q, K, and V: $O = \sum_{j=1}^{n} \mathrm{softmax}(z)_j \, v_j$ |
| $O_{\mathrm{sink}}$ | attention value with the sink parameter: $O_{\mathrm{sink}} = \sum_{j=1}^{n} \mathrm{softmax}(z')_j \, v_j$ |

From the code of gpt-oss, we obtain that

$$O_{\mathrm{sink}} = c \cdot O \quad \text{for} \quad c = \frac{\sum_{j=1}^{n} \exp(z_j)}{\exp(s) + \sum_{j=1}^{n} \exp(z_j)}.$$

Furthermore, we can derive a useful form of the scale $c$ for implementation:

$$c = \frac{\exp(\mathrm{LSE})}{\exp(s) + \exp(\mathrm{LSE})} = \frac{1}{1 + \exp(s - \mathrm{LSE})} = \mathrm{sigmoid}(\mathrm{LSE} - s)$$

with the natural log. This is the form of the scale with the sigmoid and LogSumExp functions.
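The scale derived above can be checked numerically; here is a minimal NumPy sketch (logits, values, and the sink value are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(6)        # attention logits q·k_j for one query
v = rng.standard_normal((6, 4))   # value vectors
s = 0.7                           # scalar sink logit

# Direct form: the sink participates in the softmax denominator only
w_sink = np.exp(z) / (np.exp(s) + np.exp(z).sum())
o_sink = w_sink @ v

# Rescaled form: plain softmax output times sigmoid(LSE - s)
o = (np.exp(z) / np.exp(z).sum()) @ v
lse = np.log(np.exp(z).sum())
c = 1.0 / (1.0 + np.exp(s - lse))  # sigmoid(LSE - s)

assert np.allclose(o_sink, c * o)  # both forms agree
```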
We leverage the LSE (LogSumExp) returned by the attention kernel to implement the virtual sink. The kernel returns the LSE in log base 2, so we convert it to the natural base by multiplying by $\ln 2$ before computing the sigmoid scale:
```python
match forward_to:
    case "prefill":
        attention, lse_qk = paged_kv_cache.self_attention(self.layer_idx, q=q, k=k, v=v, sm_scale=self.sm_scale)
    case "decode":
        attention, lse_qk = paged_kv_cache.cross_attention(self.layer_idx, q=q, v_head_dim=d, sm_scale=self.sm_scale)
    case "extend":
        # For continuing prefill: attend to both current tokens (self) and cached tokens (cross)
        o_self, lse_self = paged_kv_cache.self_attention(self.layer_idx, q=q, k=k, v=v, sm_scale=self.sm_scale)
        o_cross, lse_cross = paged_kv_cache.cross_attention(self.layer_idx, q=q, v_head_dim=d, sm_scale=self.sm_scale)
        attention, lse_qk = paged_kv_cache.merge_attn_output_inplace(o_self, lse_self, o_cross, lse_cross)
    case _:
        raise ValueError(f"forward_to {forward_to} not supported")

# TVM computes the LSE in base 2 internally,
# so multiply by log(2) (the natural log of 2) to convert to the natural base
lse_qk = log(2) * lse_qk
sinks_f32 = nn.op.astype(self.sinks, "float32")
sigmoid_scale = nn.op.sigmoid(lse_qk - sinks_f32)
sigmoid_scale = nn.op.unsqueeze(sigmoid_scale, dim=-1)
sigmoid_scale = nn.op.astype(sigmoid_scale, self.dtype)
attention = nn.op.multiply(attention, sigmoid_scale)
```
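The base conversion can be sanity-checked with a small sketch (logits and sink value are illustrative): multiplying a base-2 LSE by $\ln 2$ recovers the natural-log LSE that the sigmoid scale expects.

```python
import math
import numpy as np

rng = np.random.default_rng(1)
z = rng.standard_normal(8)  # per-query attention logits
sink = 0.5                  # scalar sink logit

# Pretend the kernel returned the LSE in base 2
lse_base2 = np.log2(np.exp(z).sum())

# Convert to the natural base, then apply the sigmoid scale
lse_natural = math.log(2.0) * lse_base2
scale = 1.0 / (1.0 + math.exp(sink - lse_natural))

# Matches the scale computed directly in the natural base
direct = np.exp(z).sum() / (math.exp(sink) + np.exp(z).sum())
assert math.isclose(scale, direct, rel_tol=1e-9)
```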