@@ -302,12 +302,10 @@ def _prefill(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
+        infer_state.kv_buffer_shapedtype = (
             (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
+            self.data_type,
         )
         infer_state.dist_group = dist_group_manager.get_default_group()
@@ -351,14 +349,10 @@ def _decode(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        # When the cuda graph feature is used, every inference run must follow the same code path,
-        # so we drop the contiguous-mem allocation optimization to keep the inference flow consistent.
-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
+        infer_state.kv_buffer_shapedtype = (
             (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
+            self.data_type,
         )
         infer_state.dist_group = dist_group_manager.get_default_group()
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
@@ -398,14 +392,10 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When the cuda graph feature is used, every inference run must follow the same code path,
-            # so we drop the contiguous-mem allocation optimization to keep the inference flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer = torch.empty(
+            infer_state.kv_buffer_shapedtype = (
                 (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                dtype=self.data_type,
-                device="cuda",
+                self.data_type,
             )
             infer_state.dist_group = dist_group_manager.get_group(batch_index)
             copy_kv_index_to_req(
@@ -475,14 +465,10 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When the cuda graph feature is used, every inference run must follow the same code path,
-            # so we drop the contiguous-mem allocation optimization to keep the inference flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer = torch.empty(
+            infer_state.kv_buffer_shapedtype = (
                 (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                dtype=self.data_type,
-                device="cuda",
+                self.data_type,
             )
             infer_state.dist_group = dist_group_manager.get_group(batch_index)
             init_req_to_token_indexes(
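All four hunks make the same substitution: the eagerly allocated `kv_buffer` (`torch.empty(..., device="cuda")`) and the `mem_is_contiguous` flag are dropped, and the inference state instead records a `kv_buffer_shapedtype` tuple of `(shape, dtype)` so the buffer can be materialized later, at the point of use. Below is a minimal sketch of that deferred-allocation pattern; the `InferState` stand-in, `alloc_tensor`, `get_kv_buffer`, and the dict-based pool are illustrative assumptions, not lightllm's actual API.

```python
import torch


class InferState:
    """Toy stand-in for the inference state; only the field this diff touches."""
    kv_buffer_shapedtype: tuple  # ((num_tokens, k_heads + v_heads, head_dim), dtype)


_pool = {}  # toy cache keyed by (shape, dtype); a real pool would be CUDA-graph aware


def alloc_tensor(shape, dtype, device="cuda"):
    # Hypothetical pooled allocator: reuse a cached tensor of the same shape/dtype so
    # repeated decode steps (and CUDA graph replays) see stable memory addresses.
    key = (tuple(shape), dtype)
    if key not in _pool:
        _pool[key] = torch.empty(shape, dtype=dtype, device=device)
    return _pool[key]


def get_kv_buffer(state: InferState) -> torch.Tensor:
    # Materialize the KV buffer only when needed, from the recorded (shape, dtype).
    shape, dtype = state.kv_buffer_shapedtype
    return alloc_tensor(shape, dtype)


# Usage: mirrors what the diff stores in _decode (batch_size, k+v heads, head_dim).
state = InferState()
state.kv_buffer_shapedtype = ((8, 16, 128), torch.float16)
kv = get_kv_buffer(state)  # first call allocates; later calls reuse the pooled tensor
```

Deferring the allocation this way means the state-setup code no longer issues a fresh `torch.empty` per step, which matches the removed comment's concern: under CUDA graphs every inference run must follow an identical code path, so buffer creation is better left to an allocator that can hand back the same memory on every replay.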