
Commit b539632

fix
1 parent 4af3ac5 commit b539632

File tree

7 files changed: +21, -43 lines

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 3 deletions
@@ -88,9 +88,6 @@ def __init__(self):
         self.dp_output_split_sizes: List[List[int]] = None
         self.dp_input_split_sizes: List[List[int]] = None
 
-        # Buffer dedicated to managing hybrid-attention models
-        self.buffer_indexes: torch.Tensor = None
-
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             (

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def _linear_attn(
         assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager)
 
         input = input.view(-1, infer_cls.embed_dim_)
-        buffer_idx = infer_state.buffer_indexes
+        buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx]
         conv_states, ssm_states = infer_state.mem_manager.get_buffer(self.layer_idx_)
 
         mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
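For context, a minimal sketch of the lookup the replacement line performs, with assumed sizes and slot values (only req_to_buffer_indexes and b_req_idx are names taken from the diff): per-request buffer slots are now gathered from a request-to-buffer mapping tensor on the req manager rather than read from a field carried on infer_state.

import torch

# Hypothetical capacity; the real table lives on the req manager.
max_request_num = 8
req_to_buffer_indexes = torch.full((max_request_num,), -1, dtype=torch.int64)

# Suppose requests 2 and 5 were bound to buffer slots 0 and 1 at alloc time.
req_to_buffer_indexes[2] = 0
req_to_buffer_indexes[5] = 1

b_req_idx = torch.tensor([2, 5])               # request indexes in this batch
buffer_idx = req_to_buffer_indexes[b_req_idx]  # gather per-request slots
print(buffer_idx)                              # tensor([0, 1])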

lightllm/models/qwen3next/mem_manager.py

Lines changed: 1 addition & 2 deletions
@@ -121,8 +121,7 @@ def free_buffer(self, free_buffer_indexes: List[int], reset=True):
     @override
     def alloc_buffer(self, need_size):
         # conv_state and ssm_state share the same buffer_idx
-        buffer_indexes = self.conv_state_mem_manager.alloc(need_size)
-        return buffer_indexes
+        return self.conv_state_mem_manager.alloc(need_size)
 
     @override
     def get_buffer_can_use_size(self):
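A minimal sketch of the shared-index scheme that comment describes, under stated assumptions (toy shapes, a list-based allocator standing in for conv_state_mem_manager): one buffer_idx addresses matching slots in both state pools, which is why allocating from the conv-state pool alone is enough.

import torch

num_slots, conv_dim, ssm_dim = 4, 8, 16
conv_states = torch.zeros(num_slots, conv_dim)  # pool 1
ssm_states = torch.zeros(num_slots, ssm_dim)    # pool 2, same index space

free_slots = list(range(num_slots))             # stand-in allocator

def alloc_buffer(need_size):
    # One allocation covers both pools, since they share buffer_idx.
    return [free_slots.pop() for _ in range(need_size)]

idx = alloc_buffer(1)[0]
conv_states[idx] = 1.0  # the same slot index is valid in either pool
ssm_states[idx] = 2.0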

lightllm/models/qwen3next/model.py

Lines changed: 0 additions & 8 deletions
@@ -88,14 +88,6 @@ def _init_mem_manager(self):
             mem_fraction=self.mem_fraction,
         )
 
-    @override
-    def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
-        infer_state = super()._create_inferstate(model_input, microbatch_index)
-
-        buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx]
-        infer_state.buffer_indexes = buffer_indexes
-        return infer_state
-
     @override
     def _init_req_manager(self):
         create_max_seq_len = 0

lightllm/models/qwen3next/req_manager.py

Lines changed: 14 additions & 18 deletions
@@ -24,30 +24,26 @@ def free_all(self):
         super().free_all()
         return
 
-    def free_buffer(self, free_req_indexes: List[int]):
-        from lightllm.server.router.model_infer.infer_batch import g_infer_context
-
-        if g_infer_context.radix_cache is None:
-            self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes])
-            self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX
-        return
-
-    def alloc_buffer(self, req_indexes: List[int]):
+    @override
+    def alloc(self):
         from lightllm.common.basemodel.infer_lock import g_infer_state_lock
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
-        cur_buffer_indexes = self.req_to_buffer_indexes[req_indexes]
-        empty_indexes = cur_buffer_indexes == self.EMPTY_BUFFER_INDEX
-        num_empty = empty_indexes.sum()
-        if num_empty == 0:
-            return
+        req_index = super().alloc()
 
         g_infer_state_lock.acquire()
         if g_infer_context.radix_cache is not None:
-            g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_empty)
-        new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda()
+            g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(1)
+        new_buffer_index = self.mem_manager.alloc_buffer(1)
+        self.req_to_buffer_indexes[req_index] = new_buffer_index
         g_infer_state_lock.release()
 
-        cur_buffer_indexes[empty_indexes] = new_buffer_indexes
-        self.req_to_buffer_indexes[req_indexes] = cur_buffer_indexes
+        return req_index
+
+    def free_buffer(self, free_req_indexes: List[int]):
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        if g_infer_context.radix_cache is None:
+            self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes])
+            self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX
         return
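A minimal sketch of the new flow, under stated assumptions: SketchReqManager and SketchMemManager are hypothetical stand-ins for the real req manager and Qwen3NextMemoryManager, and threading.Lock stands in for g_infer_state_lock. The point it illustrates is that a buffer slot is now bound to the request inside alloc() itself, so the prefill path no longer needs a separate alloc_buffer() call.

import threading

class SketchMemManager:
    def __init__(self, num_slots):
        self._free = list(range(num_slots))

    def alloc_buffer(self, need_size):
        assert need_size == 1
        return self._free.pop()

class SketchReqManager:
    EMPTY_BUFFER_INDEX = -1

    def __init__(self, max_request_num, mem_manager, radix_cache=None):
        self.mem_manager = mem_manager
        self.radix_cache = radix_cache
        self.lock = threading.Lock()  # stands in for g_infer_state_lock
        self.req_to_buffer_indexes = [self.EMPTY_BUFFER_INDEX] * max_request_num
        self._free_req_indexes = list(range(max_request_num))

    def alloc(self):
        req_index = self._free_req_indexes.pop()  # stands in for super().alloc()
        with self.lock:
            if self.radix_cache is not None:
                # Evict cached entries until one buffer slot can be allocated.
                self.radix_cache.free_radix_cache_to_get_enough_buffer(1)
            self.req_to_buffer_indexes[req_index] = self.mem_manager.alloc_buffer(1)
        return req_index

rm = SketchReqManager(max_request_num=4, mem_manager=SketchMemManager(4))
req = rm.alloc()
print(rm.req_to_buffer_indexes[req])  # slot bound at request-alloc time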

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 2 additions & 4 deletions
@@ -82,14 +82,12 @@ def insert_for_hybrid_radix_cache(self, reqs):
             input_token_ids = req.get_input_token_ids()
             key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
             value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu()
-            buffer_idx = req.buffer_idx
 
             # Allocate a new buffer and copy the current buffer's contents into it
-            self.mem_manager.copy_buffer(buffer_idx, new_buffer_indexes[i])
-            req.buffer_idx = new_buffer_indexes[i]
+            self.mem_manager.copy_buffer(req.buffer_idx, new_buffer_indexes[i])
 
             _, new_shared_kv_node = self.insert(key, value)
-            new_shared_kv_node.buffer_idx = buffer_idx
+            new_shared_kv_node.buffer_idx = new_buffer_indexes[i]
             self.dec_node_ref_counter(req.shared_kv_node)
             self.add_node_ref_counter(new_shared_kv_node)
             req.shared_kv_node = new_shared_kv_node
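A minimal sketch of what the fix preserves, with a plain dict standing in for the buffer pool (all names here are hypothetical): the cached radix node must own the copied slot, not the request's live one, otherwise later decode steps would overwrite the cached state.

# Slot 0 holds the request's live state; slot 1 is the freshly allocated copy.
buffers = {0: "state@prefill", 1: None}
req_buffer_idx, new_buffer_idx = 0, 1

buffers[new_buffer_idx] = buffers[req_buffer_idx]   # copy_buffer(src, dst)
node_buffer_idx = new_buffer_idx                    # the corrected assignment

buffers[req_buffer_idx] = "state@decode"            # request keeps mutating
assert buffers[node_buffer_idx] == "state@prefill"  # cached snapshot intact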

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 3 additions & 7 deletions
@@ -111,9 +111,6 @@ def prefill_normal(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
 
-        if hasattr(g_infer_context.req_manager, "req_to_buffer_indexes"):
-            g_infer_context.req_manager.alloc_buffer(model_input.b_req_idx)
-
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
             _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
@@ -132,6 +129,9 @@ def prefill_normal(
             event_pack.notify_post_handle_and_wait_pre_post_handle()
             update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
 
+            if isinstance(g_infer_context.radix_cache, HybridRadixCache):
+                g_infer_context.radix_cache.insert_for_hybrid_radix_cache(run_reqs)
+
             # Stage 3
             event_pack.notify_forward_and_wait_post_handle()
             sync_event.synchronize()
@@ -143,10 +143,6 @@ def prefill_normal(
                 extra_post_req_handle_func=self.extra_post_req_handle_func,
                 nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func,
             )
-
-            if isinstance(g_infer_context.radix_cache, HybridRadixCache):
-                g_infer_context.radix_cache.insert_for_hybrid_radix_cache(run_reqs)
-
             # Stage 4
             event_pack.notify_pre_post_handle()
             return
