Skip to content

Commit b363c9a

Browse files
committed
code clean up.
1 parent 293b4ee commit b363c9a

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -303,10 +303,9 @@ def _prefill(
303303
infer_state.req_manager = self.req_manager
304304

305305
infer_state.mem_index = mem_indexes
306-
infer_state.kv_buffer = torch.empty(
306+
infer_state.kv_buffer_shapedtype = (
307307
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
308-
dtype=self.data_type,
309-
device="cuda",
308+
self.data_type,
310309
)
311310
infer_state.dist_group = dist_group_manager.get_default_group()
312311

@@ -351,10 +350,9 @@ def _decode(
351350
infer_state.req_manager = self.req_manager
352351

353352
infer_state.mem_index = mem_indexes
354-
infer_state.kv_buffer = torch.empty(
353+
infer_state.kv_buffer_shapedtype = (
355354
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
356-
dtype=self.data_type,
357-
device="cuda",
355+
self.data_type,
358356
)
359357
infer_state.dist_group = dist_group_manager.get_default_group()
360358
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
@@ -395,10 +393,9 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
395393
infer_state.req_manager = self.req_manager
396394

397395
infer_state.mem_index = cur_batch.mem_indexes
398-
infer_state.kv_buffer = torch.empty(
396+
infer_state.kv_buffer_shapedtype = (
399397
(cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
400-
dtype=self.data_type,
401-
device="cuda",
398+
self.data_type,
402399
)
403400
infer_state.dist_group = dist_group_manager.get_group(batch_index)
404401
copy_kv_index_to_req(
@@ -469,10 +466,9 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
469466
infer_state.req_manager = self.req_manager
470467

471468
infer_state.mem_index = cur_batch.mem_indexes
472-
infer_state.kv_buffer = torch.empty(
469+
infer_state.kv_buffer_shapedtype = (
473470
(cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
474-
dtype=self.data_type,
475-
device="cuda",
471+
self.data_type,
476472
)
477473
infer_state.dist_group = dist_group_manager.get_group(batch_index)
478474
init_req_to_token_indexes(

lightllm/common/basemodel/infer_struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self):
2626
self.req_manager: ReqManager = None
2727

2828
self.mem_index = None
29-
self.kv_buffer = None
29+
self.kv_buffer_shapedtype = None
3030

3131
self.is_token_healing = False
3232
self.return_all_prompt_logics = False

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
3131
raise Exception("need to impl")
3232

3333
def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
34-
cache_kv = infer_state.kv_buffer
34+
cache_kv = self.alloc_tensor(
35+
shape=infer_state.kv_buffer_shapedtype[0],
36+
dtype=infer_state.kv_buffer_shapedtype[1],
37+
device="cuda",
38+
is_graph_out=False,
39+
microbatch_index=infer_state.microbatch_index,
40+
)
3541
return cache_kv
3642

3743
def _get_qkv(

0 commit comments

Comments
 (0)