Commit 29bb062

nvchenghaoz authored and lucaslie committed
Add sinks / sliding window for Triton backend (#95)
Signed-off-by: nvchenghaoz <[email protected]>
1 parent a8b54f9 commit 29bb062

5 files changed, +148 -22 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/_triton_attention_internal.py

Lines changed: 8 additions & 1 deletion
@@ -100,6 +100,8 @@ def _paged_generate_mha(
         n_heads,
         d_head,
         SEQ_BLOCK_SIZE,
+        False,
+        None,
     )
 
 
@@ -338,6 +340,7 @@ def _generate_mha_rope_fusion(
         d_head,
         SEQ_BLOCK_SIZE,
         HEAD_BLOCK_SIZE,
+        -1,
     )
     attention_kv_stage2[(b, n_heads, 1)](
         stage1_output_values,
@@ -348,6 +351,8 @@ def _generate_mha_rope_fusion(
         n_heads,
         d_head,
         SEQ_BLOCK_SIZE,
+        False,
+        None,
     )
 
 
@@ -414,7 +419,9 @@ def _flattened_context_mha_rope_fusion(
         d_head,
         SEQ_BLOCK,
         max_cache_seq_len,
-        num_stages=2,
+        -1,
+        False,
+        None,
     )
 
 
tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py

Lines changed: 28 additions & 3 deletions
@@ -41,6 +41,8 @@ def _generate_mha(
     input_pos: torch.Tensor,
     scale: float,
     out: torch.Tensor,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
 ):
     b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
     max_seq_len, n_kv_heads = k_cache.shape[1:3]
@@ -97,7 +99,10 @@ def _generate_mha(
         v_d_head,
         SEQ_BLOCK_SIZE,
         HEAD_BLOCK_SIZE,
+        sliding_window if sliding_window is not None else -1,
     )
+    has_sinks = sinks is not None
+
     attention_kv_stage2[(b, n_heads, 1)](
         stage1_output_values,
         stage1_output_logsumexp,
@@ -107,6 +112,8 @@ def _generate_mha(
         n_heads,
         v_d_head,
         SEQ_BLOCK_SIZE,
+        has_sinks,
+        sinks,
     )
 
 
@@ -122,6 +129,8 @@ def _flattened_context_mha(
     seq_start: torch.Tensor,
     scale: float,
     out: torch.Tensor,
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
 ) -> None:
     # NOTE: s_total == sum(seq_len)
     s_total, n_heads, q_d_head = q.shape
@@ -149,6 +158,8 @@ def _flattened_context_mha(
 
     # TODO: use input_pos to get the correct cache locations
     grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
+    has_sinks = sinks is not None
+
     context_attention_kv_flattened[grid](
         q,
         seq_len,
@@ -165,7 +176,9 @@ def _flattened_context_mha(
         v_d_head,
         SEQ_BLOCK,
         max_cache_seq_len,
-        num_stages=2,
+        sliding_window if sliding_window is not None else -1,
+        has_sinks,
+        sinks,
     )
 
 
@@ -187,6 +200,8 @@ def flattened_mha_with_cache(
     # <none>
     # CONSTANTS
     scale: Optional[float],
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
 ) -> torch.Tensor:
     """Flattened MHA with cache that takes q, k, v in BSND layout.
 
@@ -223,7 +238,9 @@ def flattened_mha_with_cache(
     y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
     if s == 1:
         # generate-only phase
-        _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y)
+        _generate_mha(
+            q, k, v, k_cache, v_cache, cache_loc, input_pos, scale, y, sinks, sliding_window
+        )
     else:
         # mixed context + generate phase
         _flattened_context_mha(
@@ -238,6 +255,8 @@ def flattened_mha_with_cache(
             seq_start,
             scale,
             y,
+            sinks,
+            sliding_window,
         )
 
     return y.view(*output_shape)
@@ -255,6 +274,8 @@ def flattened_mha_fake(
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
     scale: Optional[float],
+    sinks: Optional[torch.Tensor] = None,
+    sliding_window: Optional[int] = None,
 ):
     return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()
 
@@ -388,7 +409,11 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
         if not isinstance(scale, float):
             ad_logger.warning("Provided scale is not a float, Using default scale instead.")
             scale = None
-
+        # Get sinks and sliding_window from args or kwargs
+        sinks = extract_op_args(source_attn_node, "sinks")[0]
+        sliding_window = extract_op_args(source_attn_node, "sliding_window")[0]
         return [
             scale,  # softmax scale
+            sinks,
+            sliding_window,
         ]
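For intuition, a rough eager-mode reference of the semantics these new arguments expose: sink logits join the softmax denominator without contributing value vectors, and a sliding window of size W limits each query to its last W key positions. Shapes, the per-head sink layout, and the function name below are assumptions for illustration, not the backend implementation:

import torch


def reference_attention(q, k, v, sinks=None, sliding_window=None, scale=None):
    # q: [s_q, n_heads, d_head]; k, v: [s_k, n_heads, d_head]; sinks: [n_heads] or None.
    s_q, n_heads, d_head = q.shape
    s_k = k.shape[0]
    scale = scale if scale is not None else d_head**-0.5
    logits = torch.einsum("qhd,khd->hqk", q, k) * scale
    q_pos = torch.arange(s_q)[:, None] + (s_k - s_q)  # absolute query positions
    k_pos = torch.arange(s_k)[None, :]
    mask = k_pos <= q_pos  # causal
    if sliding_window is not None:
        mask &= k_pos >= q_pos - sliding_window + 1  # keep only the last W keys
    logits = logits.masked_fill(~mask, float("-inf"))
    if sinks is not None:
        # Sink logits enlarge the softmax denominator but carry no values.
        sink_col = sinks.view(n_heads, 1, 1).expand(n_heads, s_q, 1)
        probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)[..., :-1]
    else:
        probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)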

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_kernels/attention_with_kv_cache.py

Lines changed: 56 additions & 9 deletions
@@ -112,6 +112,7 @@ def gqa_attention_kv_stage1(
     V_D_HEAD: tl.constexpr,  # Dimension of each key/value head
     SEQ_BLOCK_SIZE: tl.constexpr,  # Block size used for tiling the sequence dim.
     HEAD_BLOCK_SIZE: tl.constexpr,  # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
+    SLIDING_WINDOW: tl.constexpr,
 ):
     """Attention kernel to be used for generate-only batches.
@@ -122,7 +123,7 @@ def gqa_attention_kv_stage1(
     Supports non-power-of-2 D_HEAD
 
     Uses flash decoding.
-    KV-cache layout is assumed to be [Batch,Seq, Head, Dim]
+    KV-cache layout is assumed to be [Batch, Seq, Head, Dim]
     1. Fetch the K-cache from 0 to input_pos
     2. Fetch the V-cache from 0 to input_pos
     3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
@@ -145,10 +146,20 @@ def gqa_attention_kv_stage1(
 
     # The number of Q heads that map to each KV head.
     HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS  # This needs to be a power-of-2
-    if seq_start_pos > kv_position:
-        return
-    seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
-    seq_mask = seq_offsets <= kv_position
+
+    # Apply sliding window constraints
+    if SLIDING_WINDOW > 0:
+        # For sliding window, limit the sequence range
+        sliding_start = tl.maximum(0, kv_position - SLIDING_WINDOW + 1)
+        if seq_start_pos + SEQ_BLOCK_SIZE <= sliding_start or seq_start_pos > kv_position:
+            return
+        seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
+        seq_mask = (seq_offsets <= kv_position) & (seq_offsets >= sliding_start)
+    else:
+        if seq_start_pos > kv_position:
+            return
+        seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
+        seq_mask = seq_offsets <= kv_position
 
     # Need to pad the head dim to 16 if HEAD_RATIO is < 16 so that tensor cores can be invoked
     #
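In the decode (stage 1) path each program handles a single query at absolute position kv_position, so the window reduces to a 1-D range check over one SEQ_BLOCK tile of keys. A small PyTorch sketch of the same mask, with illustrative names:

import torch


def decode_window_mask(kv_position: int, seq_start_pos: int, seq_block: int, sliding_window: int):
    # Keys this tile may attend to for the single query at kv_position.
    seq_offsets = seq_start_pos + torch.arange(seq_block)
    causal = seq_offsets <= kv_position
    if sliding_window > 0:
        sliding_start = max(0, kv_position - sliding_window + 1)
        return causal & (seq_offsets >= sliding_start)
    return causal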
@@ -358,6 +369,8 @@ def attention_kv_stage2(
     N_HEADS: tl.constexpr,
     D_HEAD: tl.constexpr,
     SEQ_BLOCK_SIZE: tl.constexpr,  # Nearest power of 2 for num_blocks
+    HAS_SINKS: tl.constexpr,
+    sinks_ptr,
 ):
     # There are batch * N_HEADS programs
     batch_id = tl.program_id(axis=0)
@@ -382,6 +395,11 @@ def attention_kv_stage2(
     sumexp = tl.exp(logsumexp - max_logsumexp)  # [NUM_BLOCKS_POW2]
 
     aggregate_sumexp = tl.sum(sumexp, axis=0)
+    # Add sinks contribution to the softmax denominator
+    if HAS_SINKS:
+        sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
+        sinks_exp = tl.exp(sinks_val - max_logsumexp)
+        aggregate_sumexp += sinks_exp
 
     values_offsets = block_offsets[:, None] * D_HEAD + dhead_offsets[None, :]
     values_mask = block_mask[:, None] * dhead_mask[None, :]
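Stage 2 combines the per-block partial outputs, and the sink logit only changes the normalization: the numerator is still a weighted sum over real key blocks, while the denominator gains exp(sink - max). A small illustrative reduction in PyTorch (names and shapes are assumptions):

import torch


def combine_stage1_partials(values, logsumexp, sink=None):
    # values: [num_blocks, d_head] per-block outputs; logsumexp: [num_blocks]; sink: scalar tensor or None.
    max_lse = logsumexp.max()
    weights = torch.exp(logsumexp - max_lse)
    denom = weights.sum()
    if sink is not None:
        denom = denom + torch.exp(sink - max_lse)  # sink only enlarges the denominator
    return (weights[:, None] * values).sum(dim=0) / denom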
@@ -573,6 +591,9 @@ def context_attention_kv_flattened(
     V_D_HEAD: tl.constexpr,  # Dimension of each value head.
     SEQ_BLOCK: tl.constexpr,
     MAX_SEQ_LENGTH: tl.constexpr,
+    SLIDING_WINDOW: tl.constexpr,  # Sliding window size, -1 means no sliding window
+    HAS_SINKS: tl.constexpr,
+    sinks_ptr,
 ):
     """Kernel for context phase.
 
@@ -623,7 +644,15 @@ def context_attention_kv_flattened(
     # input_pos_ptr stores the location at which kv must be written back for the given batch.
     kv_position = tl.load(input_pos_ptr + batch_id)
     num_blocks = (kv_position + seq_len + SEQ_BLOCK - 1) // SEQ_BLOCK
-    for s in range(0, num_blocks + 1, 1):
+    start = 0
+    if SLIDING_WINDOW > 0:
+        # Use the LAST query in this block for more conservative start calculation
+        last_q_pos = (
+            (seq_block_id + 1) * SEQ_BLOCK - 1 + kv_position
+        )  # Last query's absolute position
+        earliest_kv_pos = max(0, last_q_pos - SLIDING_WINDOW + 1)
+        start = max(0, earliest_kv_pos // SEQ_BLOCK)
+    for s in range(start, num_blocks + 1):
         kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
         kv_seq_mask = kv_seq_offsets < (kv_position + seq_len)
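The loop start is a block-level shortcut: the highest query position in this tile is last_q_pos, and no query in the tile can see keys earlier than last_q_pos - SLIDING_WINDOW + 1, so whole KV blocks before that point can be skipped while the per-element mask below still enforces the exact window. For example, with SEQ_BLOCK = 64, SLIDING_WINDOW = 128, kv_position = 0 and seq_block_id = 4, the last query sits at position 319, the earliest visible key is 192, and the loop starts at block 192 // 64 = 3. A tiny illustrative helper:

def first_kv_block(seq_block_id: int, seq_block: int, kv_position: int, sliding_window: int) -> int:
    # Earliest KV block any query in this tile can attend to.
    last_q_pos = (seq_block_id + 1) * seq_block - 1 + kv_position
    return max(0, last_q_pos - sliding_window + 1) // seq_block


assert first_kv_block(seq_block_id=4, seq_block=64, kv_position=0, sliding_window=128) == 3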

@@ -637,9 +666,17 @@ def context_attention_kv_flattened(
         )
         qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
         qk += tl.dot(q, k.trans())
-        qk = tl.where(
-            (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
-        )
+        # Apply causal mask
+        causal_mask = (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :]
+        # Apply sliding window mask if enabled
+        if SLIDING_WINDOW > 0:
+            sliding_window_mask = kv_seq_offsets[None, :] >= (
+                seq_offsets[:, None] + kv_position - SLIDING_WINDOW + 1
+            )
+            combined_mask = sliding_window_mask & causal_mask
+        else:
+            combined_mask = causal_mask
+        qk = tl.where(combined_mask, qk, float("-inf"))
         qk *= SCALE
         # rowmax
         m_ij = tl.maximum(tl.max(qk, 1), lse_i)
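The combined mask in the context phase is two-dimensional: the causal term keeps keys at or before each query's absolute position, and the window term additionally drops keys more than SLIDING_WINDOW - 1 positions behind it. A short eager-mode sketch of the same tile mask (illustrative names):

import torch


def tile_mask(seq_offsets, kv_seq_offsets, kv_position, sliding_window=-1):
    # seq_offsets: [SEQ_BLOCK] query offsets; kv_seq_offsets: [SEQ_BLOCK] key positions.
    causal = (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :]
    if sliding_window > 0:
        window = kv_seq_offsets[None, :] >= (seq_offsets[:, None] + kv_position - sliding_window + 1)
        return causal & window
    return causal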
@@ -662,6 +699,16 @@ def context_attention_kv_flattened(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)
 
+    # Add sinks contribution to the final softmax calculation
+    if HAS_SINKS:
+        sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
+        m_sinks = tl.maximum(m_i, sinks_val)
+        acc_scale = tl.exp(m_i - m_sinks)
+        acc = acc * acc_scale[:, None]
+        l_sinks = tl.exp(lse_i - m_sinks) + tl.exp(sinks_val - m_sinks)
+        lse_i = m_sinks + tl.log(l_sinks)
+        m_i = m_sinks
+
     o_scale = tl.exp(m_i - lse_i)
 
     acc = acc * o_scale[:, None]
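Folding the sink logit into the running flash-attention state (row max m_i, log-sum-exp lse_i, accumulator acc) is one extra rescale before the final normalization. A small PyTorch sketch of that merge step, assuming m_i and lse_i are per-row vectors and sink is a scalar tensor:

import torch


def merge_sink(acc, m_i, lse_i, sink):
    # Returns updated (acc, m_i, lse_i) with the sink folded into the softmax state.
    m_new = torch.maximum(m_i, sink)
    acc = acc * torch.exp(m_i - m_new)[:, None]  # rescale accumulator to the new max
    lse_new = m_new + torch.log(torch.exp(lse_i - m_new) + torch.exp(sink - m_new))
    return acc, m_new, lse_new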
