Skip to content

Commit a1b39c1

Browse files
authored
Perf/fuse mamba state scatter mtp verify (sgl-project#18088)
1 parent ab7071b commit a1b39c1

File tree

3 files changed

+583
-37
lines changed

3 files changed

+583
-37
lines changed

python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
ForwardMetadata,
3333
Mamba2Metadata,
3434
)
35+
from sglang.srt.layers.attention.mamba.mamba_state_scatter_triton import (
36+
fused_mamba_state_scatter_with_mask,
37+
)
3538
from sglang.srt.layers.radix_attention import RadixAttention
3639
from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
3740
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
@@ -1699,16 +1702,22 @@ def update_mamba_state_after_mtp_verify(
16991702
mamba_steps_to_track: Optional[torch.Tensor],
17001703
model,
17011704
):
1705+
"""
1706+
Update mamba states after MTP verify using fully fused Triton kernel.
1707+
1708+
This replaces the original advanced indexing operations with a single fused
1709+
gather-scatter kernel that also handles masking internally, avoiding:
1710+
- index_elementwise_kernel from tensor[bool_mask]
1711+
- index_select kernel launches
1712+
- nonzero kernel launches
1713+
"""
17021714
request_number = accepted_steps.shape[0]
17031715

17041716
state_indices_tensor = (
17051717
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
17061718
:request_number
17071719
]
17081720
)
1709-
intermediate_state_indices = torch.arange(
1710-
request_number, dtype=torch.int32, device=state_indices_tensor.device
1711-
)
17121721

17131722
mamba_caches = (
17141723
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
@@ -1719,41 +1728,34 @@ def update_mamba_state_after_mtp_verify(
17191728
intermediate_state_cache = mamba_caches.intermediate_ssm
17201729
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window[0]
17211730

1722-
# Compute common indices once to avoid duplication
1723-
valid_mask = accepted_steps >= 0
1724-
dst_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
1725-
src_state_indices = intermediate_state_indices[valid_mask].to(
1726-
torch.int64
1727-
) # [N]
1728-
last_steps = accepted_steps[valid_mask].to(torch.int64) # [N]
1729-
1730-
# scatter into ssm_states at the chosen cache lines
1731-
ssm_states[:, dst_state_indices, :] = intermediate_state_cache[
1732-
:, src_state_indices, last_steps
1733-
].to(ssm_states.dtype, copy=False)
1734-
1735-
# Scatter into conv_states at the chosen cache lines
1736-
conv_states[:, dst_state_indices, :] = intermediate_conv_window_cache[
1737-
:, src_state_indices, last_steps
1738-
].to(conv_states.dtype, copy=False)
1731+
# Use fully fused kernel that handles masking internally
1732+
# This avoids separate nonzero() and index_select() calls
1733+
fused_mamba_state_scatter_with_mask(
1734+
ssm_states,
1735+
intermediate_state_cache,
1736+
state_indices_tensor,
1737+
accepted_steps,
1738+
)
1739+
fused_mamba_state_scatter_with_mask(
1740+
conv_states,
1741+
intermediate_conv_window_cache,
1742+
state_indices_tensor,
1743+
accepted_steps,
1744+
)
17391745

17401746
# Track indices used for tracking mamba states for prefix cache
17411747
if mamba_track_indices is not None:
17421748
assert mamba_steps_to_track is not None
1743-
track_mask = mamba_steps_to_track >= 0
1744-
track_steps = mamba_steps_to_track[track_mask].to(torch.int64) # [N]
1745-
if track_steps.numel() == 0:
1746-
# No track indices to update
1747-
return
1748-
dst_track_indices = mamba_track_indices[track_mask].to(torch.int64)
1749-
src_track_indices = intermediate_state_indices[track_mask].to(torch.int64)
1750-
1751-
# scatter into ssm_states at the chosen track states
1752-
ssm_states[:, dst_track_indices, :] = intermediate_state_cache[
1753-
:, src_track_indices, track_steps
1754-
].to(ssm_states.dtype, copy=False)
1755-
1756-
# scatter into conv_states at the chosen track states
1757-
conv_states[:, dst_track_indices, :] = intermediate_conv_window_cache[
1758-
:, src_track_indices, track_steps
1759-
].to(conv_states.dtype, copy=False)
1749+
# Use fully fused kernel for track scatter operations
1750+
fused_mamba_state_scatter_with_mask(
1751+
ssm_states,
1752+
intermediate_state_cache,
1753+
mamba_track_indices,
1754+
mamba_steps_to_track,
1755+
)
1756+
fused_mamba_state_scatter_with_mask(
1757+
conv_states,
1758+
intermediate_conv_window_cache,
1759+
mamba_track_indices,
1760+
mamba_steps_to_track,
1761+
)
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
Fused Triton kernel for Mamba state scatter operations.
3+
4+
This kernel replaces the expensive advanced indexing operations in
5+
`update_mamba_state_after_mtp_verify` with a single fused gather-scatter kernel,
6+
avoiding multiple `index_elementwise_kernel` launches.
7+
"""
8+
9+
import torch
10+
import triton
11+
import triton.language as tl
12+
13+
14+
@triton.jit
def _fused_mamba_state_scatter_with_mask_kernel(
    src_ptr,
    dst_ptr,
    # Raw index arrays (before any host-side index_select / masking)
    dst_indices_raw_ptr,  # [total_requests] - state_indices_tensor
    step_indices_raw_ptr,  # [total_requests] - accepted_steps or mamba_steps_to_track
    # Total number of requests; kept for launch-interface stability even though
    # grid axis 0 already spans it (one program per request).
    total_requests,
    elem_per_entry: tl.constexpr,
    src_layer_stride,
    src_req_stride,
    src_step_stride,
    dst_layer_stride,
    dst_req_stride,
    src_req_size,
    src_step_size,
    dst_req_size,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused gather-scatter kernel with built-in masking.

    This kernel fuses the index_select operations by:
    1. Iterating over all requests (pid_req from 0 to total_requests-1)
    2. Checking if step_indices_raw[pid_req] >= 0 (valid mask)
    3. If valid, performing the scatter:
       dst[l, dst_indices_raw[pid_req], :] = src[l, pid_req, step_indices_raw[pid_req], :]

    Grid: (total_requests, num_layers, ceil(elem_per_entry / BLOCK_SIZE))
    """
    # Widen ALL program ids to int64 before any pointer arithmetic. The
    # original code widened only pid_layer/pid_block; leaving pid_req as int32
    # lets `src_idx * src_req_stride` be evaluated in 32-bit and overflow for
    # large intermediate caches.
    pid_req = tl.program_id(0).to(tl.int64)
    pid_layer = tl.program_id(1).to(tl.int64)
    pid_block = tl.program_id(2).to(tl.int64)

    # Load step index to check validity (step >= 0 means valid).
    step_idx = tl.load(step_indices_raw_ptr + pid_req).to(tl.int64)

    # Early exit if this request is not valid (step < 0).
    if step_idx < 0:
        return

    # Load destination cache-line index for this request.
    dst_idx = tl.load(dst_indices_raw_ptr + pid_req).to(tl.int64)

    # Source index is just the request index itself.
    src_idx = pid_req

    # Bounds check to avoid illegal memory access on corrupted indices.
    if not (
        (dst_idx >= 0)
        & (dst_idx < dst_req_size)
        & (src_idx >= 0)
        & (src_idx < src_req_size)
        & (step_idx < src_step_size)
    ):
        return

    # Base offsets of the (layer, request[, step]) entry in each tensor.
    src_offset = (
        pid_layer * src_layer_stride
        + src_idx * src_req_stride
        + step_idx * src_step_stride
    )
    dst_offset = pid_layer * dst_layer_stride + dst_idx * dst_req_stride

    # Element range handled by this program along the flattened state dims.
    start = pid_block * BLOCK_SIZE
    offsets = start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < elem_per_entry

    # Masked copy: load from source entry, store into destination cache line.
    data = tl.load(src_ptr + src_offset + offsets, mask=mask)
    tl.store(dst_ptr + dst_offset + offsets, data, mask=mask)
88+
89+
90+
def fused_mamba_state_scatter_with_mask(
    dst: torch.Tensor,  # [num_layers, cache_size, *state_shape]
    src: torch.Tensor,  # [num_layers, spec_size, draft_tokens, *state_shape]
    dst_indices_raw: torch.Tensor,  # [total_requests] - raw indices (e.g., state_indices_tensor)
    step_indices_raw: torch.Tensor,  # [total_requests] - raw step indices (step >= 0 means valid)
):
    """
    Fully fused gather-scatter with built-in masking for mamba state updates.

    Replaces this sequence of eager-mode operations with one kernel launch:
    1. valid_mask = step_indices_raw >= 0
    2. valid_indices = valid_mask.nonzero()
    3. dst_indices = dst_indices_raw[valid_indices] (index_select)
    4. step_indices = step_indices_raw[valid_indices] (index_select)
    5. for each valid i: dst[:, dst_indices[i], :] = src[:, i, step_indices[i], :]

    Args:
        dst: Destination tensor [num_layers, cache_size, *state_shape]
        src: Source tensor [num_layers, spec_size, draft_tokens, *state_shape]
        dst_indices_raw: Raw destination indices for all requests [total_requests]
        step_indices_raw: Raw step indices; entry >= 0 means valid [total_requests]
    """
    num_requests = step_indices_raw.shape[0]
    # Nothing to scatter; skip all validation and the kernel launch.
    if not num_requests:
        return

    # --- Input validation (kept in the same order as raised errors) ---
    if dst.device != src.device:
        raise ValueError(
            f"dst and src must be on the same device. {dst.device=} {src.device=}"
        )
    if not (dst.is_cuda and src.is_cuda):
        raise ValueError(
            "fused_mamba_state_scatter_with_mask only supports CUDA tensors."
        )
    if not (dst.ndim >= 2 and src.ndim >= 3):
        raise ValueError(f"Unexpected tensor ranks: {dst.ndim=} {src.ndim=}")
    if dst.shape[0] != src.shape[0]:
        raise ValueError(
            f"Layer dimension mismatch: {dst.shape[0]=} vs {src.shape[0]=}"
        )
    if dst.shape[2:] != src.shape[3:]:
        raise ValueError(
            f"Trailing dims mismatch: {dst.shape[2:]=} vs {src.shape[3:]=}"
        )
    if dst_indices_raw.ndim != 1 or step_indices_raw.ndim != 1:
        raise ValueError(
            f"indices must be 1D: {dst_indices_raw.shape=} {step_indices_raw.shape=}"
        )
    if dst_indices_raw.shape[0] != step_indices_raw.shape[0]:
        raise ValueError(
            f"indices length mismatch: {dst_indices_raw.shape[0]=} vs {step_indices_raw.shape[0]=}"
        )

    # --- Geometry ---
    layer_count, dst_cache_size = dst.shape[0], dst.shape[1]
    spec_size, draft_steps = src.shape[1], src.shape[2]

    # Elements per (layer, cache_line) entry, i.e. the flattened state dims.
    entry_elems = dst.numel() // (layer_count * dst_cache_size)

    # Strides in elements (not bytes), grouped per tensor.
    src_strides = (src.stride(0), src.stride(1), src.stride(2))
    dst_strides = (dst.stride(0), dst.stride(1))

    # Kernel loads indices as int32, so normalize dtype and layout here.
    dst_indices_raw = dst_indices_raw.to(torch.int32).contiguous()
    step_indices_raw = step_indices_raw.to(torch.int32).contiguous()

    # The flat-offset copy assumes densely packed state tensors.
    if not dst.is_contiguous():
        raise ValueError("dst tensor must be contiguous")
    if not src.is_contiguous():
        raise ValueError("src tensor must be contiguous")

    # Elements copied per program along the state dimensions.
    BLOCK_SIZE = 1024

    # One program per (request, layer, element-block); invalid requests
    # (step < 0) early-exit inside the kernel.
    grid = (num_requests, layer_count, triton.cdiv(entry_elems, BLOCK_SIZE))

    _fused_mamba_state_scatter_with_mask_kernel[grid](
        src,
        dst,
        dst_indices_raw,
        step_indices_raw,
        num_requests,
        entry_elems,
        src_strides[0],
        src_strides[1],
        src_strides[2],
        dst_strides[0],
        dst_strides[1],
        spec_size,
        draft_steps,
        dst_cache_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

0 commit comments

Comments
 (0)