@@ -71,6 +71,40 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream:
         self.cpu_kv_cache_stream = torch.cuda.Stream()
         return self.cpu_kv_cache_stream

+    def _maybe_alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None:
75+ """
76+ For hybrid/linear-attention models (e.g. Qwen3-Next) we allocate a fixed-size buffer per request.
77+ If radix cache hits and the matched node has a buffer, copy that buffer content to the newly
78+ allocated buffer for this request.
79+ """
+        if not self.use_buffer_manager or not req_objs:
+            return
+
+        if self.radix_cache is not None:
+            # Ensure enough buffer capacity by evicting radix cache buffers if needed.
+            self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs))
+
+        req_idxs = np.array([r.req_idx for r in req_objs], dtype=np.int64)
+        request_indices_gpu = torch.from_numpy(req_idxs).to(device="cuda", dtype=torch.int64)
+        self.req_manager.alloc_buffer_for_req(request_indices_gpu)
+
+        if self.radix_cache is None:
+            return
+
+        # `shared_kv_node` may be None on cache miss; treat it as "no buffer to copy".
+        buffer_idxs = np.array(
+            [None if r.shared_kv_node is None else r.shared_kv_node.buffer_idx for r in req_objs], dtype=object
+        )
+        mask = buffer_idxs == None  # noqa: E711 (intentional elementwise comparison against None)
+        copy_indices = req_idxs[~mask].tolist()
+        if not copy_indices:
+            return
+
+        copy_buffers = buffer_idxs[~mask].tolist()
+        copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64)
+        copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64)
+        self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor)
+
     def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
         req_objs = []
         request_ids = []
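For readers unfamiliar with the object-dtype masking in `_maybe_alloc_and_copy_req_buffers`, here is a minimal, self-contained sketch (made-up request indices and buffer indices, only numpy required) of how the requests whose matched radix node actually carries a buffer are selected:

```python
import numpy as np

# Hypothetical (req_idx, source buffer_idx) pairs; None means the radix lookup missed
# or the matched node owns no buffer, so there is nothing to copy for that request.
reqs = [(3, 7), (5, None), (9, 2)]

req_idxs = np.array([r[0] for r in reqs], dtype=np.int64)
buffer_idxs = np.array([r[1] for r in reqs], dtype=object)

# On an object-dtype array, `== None` compares elementwise and yields a boolean mask.
mask = buffer_idxs == None  # noqa: E711
copy_indices = req_idxs[~mask].tolist()     # [3, 9] -> destination request slots
copy_buffers = buffer_idxs[~mask].tolist()  # [7, 2] -> source buffers to copy from

print(copy_indices, copy_buffers)
```

The object dtype is what allows `None` to sit next to integer buffer indices in a single array; a plain integer dtype would reject it.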
@@ -109,19 +143,16 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
                slave_req: InferReq = slave_req
                slave_req.related_master_req = master_req

-        # Linear-attention models allocate one buffer per request.
-        if self.use_buffer_manager and len(request_ids) > 0:
-            if self.radix_cache is not None:
-                self.radix_cache.free_radix_cache_to_get_enough_buffer(len(request_ids))
-            self.req_manager.alloc_buffer_for_req(torch.tensor(request_ids, dtype=torch.int64, device="cpu"))
+        # Hybrid/linear-attention models: allocate per-request buffers and copy radix-matched content.
+        self._maybe_alloc_and_copy_req_buffers(req_objs)

         return req_objs

     def free_a_req_mem(self, free_token_index: List, req: "InferReq", free_buffer_index: List = None):
         if self.radix_cache is None:
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
             if self.use_buffer_manager:
-                free_buffer_index.append(self.req_manager.req_to_buffer_indexs[req.req_idx])
+                free_buffer_index.append(self.req_manager.req_to_buffer_index[req.req_idx])
         else:
             input_token_ids = req.get_input_token_ids()
             key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
@@ -131,9 +162,9 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", free_buffer_in
             prefix_len, node = self.radix_cache.insert(key, value)
             if self.use_buffer_manager:
                 if node.buffer_idx is None:
-                    node.buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx]
+                    node.buffer_idx = self.req_manager.req_to_buffer_index[req.req_idx]
                 else:
-                    free_buffer_index.append(self.req_manager.req_to_buffer_indexes[req.req_idx])
+                    free_buffer_index.append(self.req_manager.req_to_buffer_index[req.req_idx])

             old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
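The `node.buffer_idx` branch above is a simple ownership handoff: each radix node keeps at most one buffer, adopting the request's buffer if it has none and recycling the request's buffer otherwise. A minimal sketch of that rule, using a hypothetical `Node` dataclass and a plain list in place of the real free-buffer bookkeeping:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    buffer_idx: Optional[int] = None  # hypothetical stand-in for a radix-tree node


def hand_off_buffer(node: Node, req_buffer_idx: int, free_buffer_index: List[int]) -> None:
    """Keep at most one buffer per node; recycle the request's buffer otherwise."""
    if node.buffer_idx is None:
        node.buffer_idx = req_buffer_idx          # node adopts the request's buffer
    else:
        free_buffer_index.append(req_buffer_idx)  # node is already backed; free the duplicate


free_list: List[int] = []
node = Node()
hand_off_buffer(node, 4, free_list)  # empty node adopts buffer 4
hand_off_buffer(node, 9, free_list)  # node keeps 4, buffer 9 goes back to the pool
print(node.buffer_idx, free_list)    # 4 [9]
```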
@@ -179,9 +210,6 @@ def _filter(self, finished_request_ids: List[int]):
         free_token_index = custom_cat(free_token_index)
         self.req_manager.free(free_req_index, free_token_index)

-        if self.use_buffer_manager and len(free_buffer_index) != 0:
-            self.req_manager.free_buffer(free_buffer_index)
-
         finished_req_ids_set = set(finished_request_ids)
         self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set]

@@ -208,11 +236,11 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
         if pause_reqs:
             g_infer_state_lock.acquire()

-            pause_req_ids = []
+            pause_req_indices = []
             free_token_index = []
             free_buffer_index = []
             for req in pause_reqs:
-                pause_req_ids.append(req.req_id)
+                pause_req_indices.append(req.req_idx)
                 if self.args.diverse_mode:
                     # When a request is paused, clear the master/slave relation used in diverse mode.
                     req.clear_master_slave_state()
@@ -230,8 +258,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
             self.req_manager.free_token(free_token_index)

             if self.use_buffer_manager and len(free_buffer_index) != 0:
-                pause_req_ids = torch.tensor(pause_req_ids, dtype=torch.int64, device="cpu")
-                self.req_manager.req_has_buffer[pause_req_ids] = False
+                pause_req_indices = torch.tensor(pause_req_indices, dtype=torch.int64, device="cpu")
                 self.req_manager.free_buffer(free_buffer_index)

             g_infer_state_lock.release()
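The switch from `req.req_id` to `req.req_idx` matters because the request manager's per-request tensors are sized by the number of request slots, not by the ever-growing global request IDs. A small illustration of the distinction (hypothetical names and sizes, not the actual `ReqManager` layout):

```python
import torch

max_request_num = 8  # hypothetical number of request slots held by the request manager
req_to_buffer_index = torch.full((max_request_num,), -1, dtype=torch.int64)

# A long-running server hands out global request IDs far beyond the slot count,
# while each active request occupies one small slot index.
req_id, req_idx = 12345, 3

req_to_buffer_index[req_idx] = 7   # correct: per-request state is indexed by slot
# req_to_buffer_index[req_id] = 7  # would raise IndexError: 12345 is out of bounds
print(req_to_buffer_index)
```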
@@ -240,9 +267,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
     def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
         if paused_reqs:
             g_infer_state_lock.acquire()
-            recover_paused_req_ids = []
             for req in paused_reqs:
-                recover_paused_req_ids.append(req.req_id)
                 prefill_need_token_num = req.get_cur_total_len()
                 if prefill_need_token_num > can_alloc_token_num:
                     break
@@ -253,13 +278,8 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                 req.shm_req.is_paused = False
                 can_alloc_token_num -= prefill_need_token_num

-            if self.use_buffer_manager and len(recover_paused_req_ids) != 0:
-                if self.radix_cache is not None:
-                    self.radix_cache.free_radix_cache_to_get_enough_buffer(len(recover_paused_req_ids))
-                self.req_manager.alloc_buffer_for_req(
-                    torch.tensor(recover_paused_req_ids, dtype=torch.int64, device="cpu")
-                )
+            self._maybe_alloc_and_copy_req_buffers(paused_reqs)
             g_infer_state_lock.release()
         return

     def get_can_alloc_token_num(self):