
Commit 550395d

fix
1 parent 34ec3a1 commit 550395d

8 files changed: +170, -66 lines changed
lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import torch
import triton
import triton.language as tl


@triton.jit
def alloc_buffer_for_req_kernel(
    req_index_ptr,  # [num_reqs] - indices of requests to allocate buffers for
    buffer_indexes_ptr,  # [num_reqs] - buffer indices to assign (from CPU)
    req_to_buffer_index_ptr,  # [max_request_num + 1] - tensor mapping req_idx to buffer_idx
    num_reqs,  # number of requests to process
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Mask for valid indices
    mask = offsets < num_reqs

    # Load request indices and buffer indices
    req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0)
    buffer_indices = tl.load(buffer_indexes_ptr + offsets, mask=mask, other=0)

    # Update req_to_buffer_index[req_indices] = buffer_indices
    tl.store(req_to_buffer_index_ptr + req_indices, buffer_indices, mask=mask)


def alloc_buffer_for_req_triton(
    req_index: torch.Tensor,  # [num_reqs] int32/int64 tensor on CUDA
    buffer_indexes: torch.Tensor,  # [num_reqs] int32 tensor (can be CPU or CUDA)
    req_to_buffer_index: torch.Tensor,  # [max_request_num + 1] int32 tensor on CUDA
):
    num_reqs = req_index.shape[0]

    # Ensure inputs are on CUDA
    if not req_index.is_cuda:
        req_index = req_index.cuda()
    if not buffer_indexes.is_cuda:
        buffer_indexes = buffer_indexes.cuda()

    # Ensure correct dtypes
    if req_index.dtype not in [torch.int32, torch.int64]:
        req_index = req_index.to(torch.int32)
    if buffer_indexes.dtype != torch.int32:
        buffer_indexes = buffer_indexes.to(torch.int32)

    # Launch kernel
    BLOCK_SIZE = 256
    grid = (triton.cdiv(num_reqs, BLOCK_SIZE),)

    alloc_buffer_for_req_kernel[grid](
        req_index,
        buffer_indexes,
        req_to_buffer_index,
        num_reqs,
        BLOCK_SIZE=BLOCK_SIZE,
    )


# Convenience function that matches the original API
def alloc_buffer_for_req_wrapper(
    req_manager,
    req_index: list,
    buffer_indexes: torch.Tensor,
):
    """
    Wrapper function to integrate with ReqManagerWithBuffer.

    Usage in ReqManagerWithBuffer:
        def alloc_buffer_for_req(self, req_index: List[int]):
            buffer_indexes = self.mem_manager.alloc_buffer(len(req_index))  # cpu tensor
            # Replace the plain indexed assignment
            # self.req_to_buffer_index[req_index] = buffer_indexes
            # with the Triton kernel:
            from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton
            req_index_tensor = torch.tensor(req_index, dtype=torch.int32, device="cuda")
            alloc_buffer_for_req_triton(
                req_index_tensor,
                buffer_indexes,
                self.req_to_buffer_index,
            )
    """
    req_index_tensor = torch.tensor(req_index, dtype=torch.int32, device="cuda")
    alloc_buffer_for_req_triton(
        req_index_tensor,
        buffer_indexes,
        req_manager.req_to_buffer_index,
    )
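
As a quick sanity check of the kernel above, a minimal invocation might look like the sketch below; the table size, request indices, and buffer values are made up, and a CUDA device with Triton installed is assumed.

import torch

from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton

# Illustrative values: a mapping table for up to 8 requests (+1 HOLD slot),
# with requests 2 and 5 receiving buffer slots 3 and 4 from the allocator.
req_to_buffer_index = torch.zeros(8 + 1, dtype=torch.int32, device="cuda")
req_index = torch.tensor([2, 5], dtype=torch.int64, device="cuda")
buffer_indexes = torch.tensor([3, 4], dtype=torch.int32)  # e.g. what mem_manager.alloc_buffer(2) could return

alloc_buffer_for_req_triton(req_index, buffer_indexes, req_to_buffer_index)

# Semantically this is the scatter req_to_buffer_index[req_index] = buffer_indexes:
assert req_to_buffer_index[2].item() == 3 and req_to_buffer_index[5].item() == 4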

lightllm/common/req_manager.py

Lines changed: 13 additions & 7 deletions
@@ -1,5 +1,6 @@
 import torch
 import collections
+from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton
 from lightllm.utils.log_utils import init_logger
 from .kv_cache_mem_manager import MemoryManager
 from typing import List, Optional
@@ -243,27 +244,32 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
 class ReqManagerWithBuffer(ReqManager):
     def __init__(self, max_request_num, max_sequence_length, mem_manager):
         super().__init__(max_request_num, max_sequence_length, mem_manager)
-        self.req_has_buffer = torch.zeros((self.max_request_num + 1), dtype=torch.bool, device="cuda")
         self.req_to_buffer_index = torch.zeros((self.max_request_num + 1), dtype=torch.int32, device="cuda")
         self.req_to_buffer_index[self.HOLD_REQUEST_ID] = self.mem_manager.HOLD_BUFFER_INDEX

     @override
     def free(self, free_req_indexes: List[int], free_token_index):
         super().free(free_req_indexes, free_token_index)
-        self.req_has_buffer[free_req_indexes] = False
         self.free_buffer(self.req_to_buffer_index[free_req_indexes])

     @override
     def free_all(self):
-        self.req_has_buffer.zero_()
         super().free_all()
         return

     def free_buffer(self, free_buffer_indexes: List[int]):
         self.mem_manager.free_buffer(free_buffer_indexes)
         return

-    def alloc_buffer_for_req(self, req_index: int):
-        self.req_has_buffer[req_index] = True
-        buffer_indexes = self.mem_manager.alloc_buffer(len(req_index))
-        self.req_to_buffer_index[req_index] = buffer_indexes
+    def alloc_buffer_for_req(self, req_index: torch.Tensor):
+        buffer_indexes = self.mem_manager.alloc_buffer(req_index.shape[0])
+        alloc_buffer_for_req_triton(req_index, buffer_indexes, self.req_to_buffer_index)
+
+    def reset_buffer(self, req_index: torch.Tensor):
+        buffer_indexes = self.req_to_buffer_index[req_index]
+        self.mem_manager.reset_buffer(buffer_indexes)
+        return
+
+    def copy_buffer_from_another_buffer(self, src_buffer_index: int, tgt_req_index: int):
+        self.mem_manager.copy_buffer(src_buffer_index, self.req_to_buffer_index[tgt_req_index])
+        return

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def _linear_attn(
     assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)

     input = input.view(-1, infer_cls.embed_dim_)
-    buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx]
+    buffer_idx = infer_state.req_manager.req_to_buffer_index[infer_state.b_req_idx]
     conv_states, ssm_states = infer_state.mem_manager.get_buffer(self.layer_idx_)

     mixed_qkvzba = layer_weight.linear_in_proj.mm(input)

lightllm/models/qwen3next/mem_manager.py

Lines changed: 2 additions & 4 deletions
@@ -112,10 +112,10 @@ def get_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]:
         return self.conv_state_mem_manager.buffer[real_layer_index], self.ssm_state_mem_manager.buffer[real_layer_index]

     @override
-    def free_buffer(self, free_buffer_indexes: List[int], reset=True):
+    def free_buffer(self, free_buffer_indexes: List[int], reset_to_zero=True):
         # conv_state and ssm_state share the same buffer_idx
         self.conv_state_mem_manager.free(free_buffer_indexes)
-        if reset:
+        if reset_to_zero:
             self.conv_state_mem_manager.buffer[:, free_buffer_indexes] = 0
             self.ssm_state_mem_manager.buffer[:, free_buffer_indexes] = 0

@@ -130,8 +130,6 @@ def get_buffer_can_use_size(self):

     @override
     def copy_buffer(self, src_idx, tgt_idx):
-        assert src_idx is not None and tgt_idx is not None
-        assert src_idx != tgt_idx
         # Use slice operation and in-place copy for better performance
         self.conv_state_mem_manager.buffer[:, tgt_idx].copy_(self.conv_state_mem_manager.buffer[:, src_idx])
         self.ssm_state_mem_manager.buffer[:, tgt_idx].copy_(self.ssm_state_mem_manager.buffer[:, src_idx])
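
The reset_to_zero and copy_buffer paths above reduce to plain slice operations on the per-layer state buffers; the following standalone sketch (made-up shapes, CPU tensors) mirrors those two operations.

import torch

# Stand-in for the conv/ssm state buffers with layout [num_layers, num_buffer_slots, state_dim];
# the real managers keep one such tensor per state kind.
buffer = torch.randn(2, 4, 3)

# free_buffer(..., reset_to_zero=True): zero the freed slots across all layers.
free_buffer_indexes = [1, 3]
buffer[:, free_buffer_indexes] = 0

# copy_buffer(src_idx, tgt_idx): in-place slice copy from one slot to another.
src_idx, tgt_idx = 0, 2
buffer[:, tgt_idx].copy_(buffer[:, src_idx])

assert torch.equal(buffer[:, tgt_idx], buffer[:, src_idx])
assert torch.count_nonzero(buffer[:, 1]) == 0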

lightllm/models/qwen3next/model.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 from lightllm.server.core.objs.start_args_type import StartArgs
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.req_manager import ReqManagerWithBuffer
+from lightllm.server.router.model_infer.infer_batch import g_infer_context

 logger = init_logger(__name__)

@@ -38,7 +39,7 @@ def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch
         # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py
         triton.set_allocator(_triton_allocator)
         logger.info("Triton allocator set for Qwen3Next model")
-
+        g_infer_context.use_buffer_manager = True
         super().__init__(kvargs)

     @override

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 17 additions & 28 deletions
@@ -59,6 +59,7 @@ def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token
             node = self.evict_buffer_set.pop(0)
             assert node.buffer_idx is not None
             evict_buffer_callback(node.buffer_idx)
+            node.buffer_idx = None
             need_evict_buffer_num -= 1
            # Once a node's buffer_idx becomes None it can no longer be matched later,
            # but it still has to be kept while it has children or a non-zero ref count; otherwise it should be deleted
@@ -73,38 +74,25 @@ def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token
         return

     def insert_for_hybrid_radix_cache(self, reqs):
-        # Retain the prefix cache while the request is still running, rather than when it is released
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
-        from lightllm.common.basemodel.infer_lock import g_infer_state_lock

-        # Filter out requests whose cur_kv_len is 0 (new requests that have not produced any KV yet)
-        valid_reqs = [req for req in reqs if req.cur_kv_len > 0]
+        self.free_radix_cache_to_get_enough_buffer(len(reqs))
+        new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs))

-        if len(valid_reqs) == 0:
-            return
+        for i, req in enumerate(reqs):
+            input_token_ids = req.get_input_token_ids()
+            key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
+            value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
+            cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_index[req.req_idx]

-        # Make sure there is enough room for the new buffers and do all radix cache operations under the lock
-        g_infer_state_lock.acquire()
-        try:
-            self.free_radix_cache_to_get_enough_buffer(len(valid_reqs))
-            new_buffer_indexes = self.mem_manager.alloc_buffer(len(valid_reqs))
+            # Allocate a new buffer and copy the contents of the current buffer
+            self.mem_manager.copy_buffer(cur_buffer_idx, new_buffer_indexes[i])

-            for i, req in enumerate(valid_reqs):
-                input_token_ids = req.get_input_token_ids()
-                key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
-                value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
-                cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[req.req_idx]
-
-                # Allocate a new buffer and copy the contents of the current buffer
-                self.mem_manager.copy_buffer(cur_buffer_idx, new_buffer_indexes[i])
-
-                _, new_shared_kv_node = super().insert(key, value)
-                self.dec_node_ref_counter(req.shared_kv_node)
-                self.add_node_ref_counter(new_shared_kv_node)
-                new_shared_kv_node.buffer_idx = new_buffer_indexes[i]
-                req.shared_kv_node = new_shared_kv_node
-        finally:
-            g_infer_state_lock.release()
+            _, new_shared_kv_node = super().insert(key, value)
+            self.dec_node_ref_counter(req.shared_kv_node)
+            self.add_node_ref_counter(new_shared_kv_node)
+            new_shared_kv_node.buffer_idx = new_buffer_indexes[i]
+            req.shared_kv_node = new_shared_kv_node

     def match_prefix(self, key, update_refs=False):
         assert len(key) != 0
@@ -184,12 +172,13 @@ def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback):
             node: TreeNode = self.evict_tree_set.pop(0)
             assert (
                 node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node
-            ), "error evict tree node state"
+            ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}"
             num_evicted += len(node.token_mem_index_value)
             evict_callback(node.token_mem_index_value)
             if node.buffer_idx is not None:
                 self.evict_buffer_set.discard(node)
                 evict_buffer_callback(node.buffer_idx)
+                node.buffer_idx = None
             # update total token num
             self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
             parent_node: TreeNode = node.parent

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 43 additions & 24 deletions
@@ -71,6 +71,40 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream:
             self.cpu_kv_cache_stream = torch.cuda.Stream()
         return self.cpu_kv_cache_stream

+    def _maybe_alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None:
+        """
+        For hybrid/linear-attention models (e.g. Qwen3-Next) we allocate a fixed-size buffer per request.
+        If radix cache hits and the matched node has a buffer, copy that buffer content to the newly
+        allocated buffer for this request.
+        """
+        if not self.use_buffer_manager or not req_objs:
+            return
+
+        if self.radix_cache is not None:
+            # Ensure enough buffer capacity by evicting radix cache buffers if needed.
+            self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs))
+
+        req_idxs = np.array([r.req_idx for r in req_objs], dtype=np.int64)
+        request_indices_gpu = torch.from_numpy(req_idxs).to(device="cuda", dtype=torch.int64)
+        self.req_manager.alloc_buffer_for_req(request_indices_gpu)
+
+        if self.radix_cache is None:
+            return
+
+        # `shared_kv_node` may be None on cache miss; treat it as "no buffer to copy".
+        buffer_idxs = np.array(
+            [None if r.shared_kv_node is None else r.shared_kv_node.buffer_idx for r in req_objs], dtype=object
+        )
+        mask = buffer_idxs == None  # noqa: E711 (intentional elementwise comparison against None)
+        copy_indices = req_idxs[~mask].tolist()
+        if not copy_indices:
+            return
+
+        copy_buffers = buffer_idxs[~mask].tolist()
+        copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64)
+        copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64)
+        self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor)
+
     def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
         req_objs = []
         request_ids = []
@@ -109,19 +143,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
             slave_req: InferReq = slave_req
             slave_req.related_master_req = master_req

-        # Linear-attention models allocate one buffer per request
-        if self.use_buffer_manager and len(request_ids) > 0:
-            if self.radix_cache is not None:
-                self.radix_cache.free_radix_cache_to_get_enough_buffer(len(request_ids))
-            self.req_manager.alloc_buffer_for_req(torch.tensor(request_ids, dtype=torch.int64, device="cpu"))
+        # Hybrid/linear-attention models
+        self._maybe_alloc_and_copy_req_buffers(req_objs)

         return req_objs

     def free_a_req_mem(self, free_token_index: List, req: "InferReq", free_buffer_index: List = None):
         if self.radix_cache is None:
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
             if self.use_buffer_manager:
-                free_buffer_index.append(self.req_manager.req_to_buffer_indexs[req.req_idx])
+                free_buffer_index.append(self.req_manager.req_to_buffer_index[req.req_idx])
         else:
             input_token_ids = req.get_input_token_ids()
             key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
@@ -131,9 +162,9 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", free_buffer_in
             prefix_len, node = self.radix_cache.insert(key, value)
             if self.use_buffer_manager:
                 if node.buffer_idx is None:
-                    node.buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx]
+                    node.buffer_idx = self.req_manager.req_to_buffer_index[req.req_idx]
                 else:
-                    free_buffer_index.append(self.req_manager.req_to_buffer_indexes[req.req_idx])
+                    free_buffer_index.append(self.req_manager.req_to_buffer_index[req.req_idx])

             old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
@@ -179,9 +210,6 @@ def _filter(self, finished_request_ids: List[int]):
             free_token_index = custom_cat(free_token_index)
             self.req_manager.free(free_req_index, free_token_index)

-            if self.use_buffer_manager and len(free_buffer_index) != 0:
-                self.req_manager.free_buffer(free_buffer_index)
-
         finished_req_ids_set = set(finished_request_ids)
         self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set]

@@ -208,11 +236,11 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
         if pause_reqs:
             g_infer_state_lock.acquire()

-            pause_req_ids = []
+            pause_req_indices = []
             free_token_index = []
             free_buffer_index = []
             for req in pause_reqs:
-                pause_req_ids.append(req.req_id)
+                pause_req_indices.append(req.req_idx)
                 if self.args.diverse_mode:
                     # When pausing, the master/slave relationship used in diverse mode must be cleared
                     req.clear_master_slave_state()
@@ -230,8 +258,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
             self.req_manager.free_token(free_token_index)

             if self.use_buffer_manager and len(free_buffer_index) != 0:
-                pause_req_ids = torch.tensor(pause_req_ids, dtype=torch.int64, device="cpu")
-                self.req_manager.req_has_buffer[pause_req_ids] = False
+                pause_req_indices = torch.tensor(pause_req_indices, dtype=torch.int64, device="cpu")
                 self.req_manager.free_buffer(free_buffer_index)

             g_infer_state_lock.release()
@@ -240,9 +267,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
     def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
         if paused_reqs:
             g_infer_state_lock.acquire()
-            recover_paused_req_ids = []
             for req in paused_reqs:
-                recover_paused_req_ids.append(req.req_id)
                 prefill_need_token_num = req.get_cur_total_len()
                 if prefill_need_token_num > can_alloc_token_num:
                     break
@@ -253,13 +278,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                     req.shm_req.is_paused = False
                     can_alloc_token_num -= prefill_need_token_num

-            if self.use_buffer_manager and len(recover_paused_req_ids) != 0:
-                if self.radix_cache is not None:
-                    self.radix_cache.free_radix_cache_to_get_enough_buffer(len(recover_paused_req_ids))
-                self.req_manager.alloc_buffer_for_req(
-                    torch.tensor(recover_paused_req_ids, dtype=torch.int64, device="cpu")
-                )
-            g_infer_state_lock.release()
+            self._maybe_alloc_and_copy_req_buffers(paused_reqs)
         return

     def get_can_alloc_token_num(self):
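
For reference, the None masking in _maybe_alloc_and_copy_req_buffers relies on NumPy's elementwise comparison over an object array; a tiny standalone sketch with made-up indices:

import numpy as np

# Made-up per-request data: request 9 hit the radix cache with buffer 3, the others missed.
req_idxs = np.array([4, 9, 11], dtype=np.int64)
buffer_idxs = np.array([None, 3, None], dtype=object)  # node.buffer_idx, or None on a cache miss

mask = buffer_idxs == None  # noqa: E711 - elementwise comparison on an object array
copy_indices = req_idxs[~mask].tolist()     # requests that get a buffer copy
copy_buffers = buffer_idxs[~mask].tolist()  # the source buffers to copy from

assert copy_indices == [9] and copy_buffers == [3]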

0 commit comments