@@ -303,10 +303,9 @@ def _prefill(
303303 infer_state .req_manager = self .req_manager
304304
305305 infer_state .mem_index = mem_indexes
306- infer_state .kv_buffer = torch . empty (
306+ infer_state .kv_buffer_shapedtype = (
307307 (input_ids .shape [0 ], self .tp_k_head_num_ + self .tp_v_head_num_ , self .head_dim_ ),
308- dtype = self .data_type ,
309- device = "cuda" ,
308+ self .data_type ,
310309 )
311310 infer_state .dist_group = dist_group_manager .get_default_group ()
312311
@@ -351,10 +350,9 @@ def _decode(
351350 infer_state .req_manager = self .req_manager
352351
353352 infer_state .mem_index = mem_indexes
354- infer_state .kv_buffer = torch . empty (
353+ infer_state .kv_buffer_shapedtype = (
355354 (batch_size , self .tp_k_head_num_ + self .tp_v_head_num_ , self .head_dim_ ),
356- dtype = self .data_type ,
357- device = "cuda" ,
355+ self .data_type ,
358356 )
359357 infer_state .dist_group = dist_group_manager .get_default_group ()
360358 copy_kv_index_to_req (self .req_manager .req_to_token_indexs , b_req_idx , b_seq_len , infer_state .mem_index )
@@ -395,10 +393,9 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
395393 infer_state .req_manager = self .req_manager
396394
397395 infer_state .mem_index = cur_batch .mem_indexes
398- infer_state .kv_buffer = torch . empty (
396+ infer_state .kv_buffer_shapedtype = (
399397 (cur_batch .batch_size , self .tp_k_head_num_ + self .tp_v_head_num_ , self .head_dim_ ),
400- dtype = self .data_type ,
401- device = "cuda" ,
398+ self .data_type ,
402399 )
403400 infer_state .dist_group = dist_group_manager .get_group (batch_index )
404401 copy_kv_index_to_req (
@@ -469,10 +466,9 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
469466 infer_state .req_manager = self .req_manager
470467
471468 infer_state .mem_index = cur_batch .mem_indexes
472- infer_state .kv_buffer = torch . empty (
469+ infer_state .kv_buffer_shapedtype = (
473470 (cur_batch .input_ids .shape [0 ], self .tp_k_head_num_ + self .tp_v_head_num_ , self .head_dim_ ),
474- dtype = self .data_type ,
475- device = "cuda" ,
471+ self .data_type ,
476472 )
477473 infer_state .dist_group = dist_group_manager .get_group (batch_index )
478474 init_req_to_token_indexes (
0 commit comments