
Commit 7e45d78

remove kv buffer for decode
1 parent a44574e commit 7e45d78

File tree

2 files changed, +6 -6 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 5 deletions
@@ -334,11 +334,6 @@ def _decode(
         # so the optimization from allocating contiguous mem is no longer used, to keep the inference flow consistent
         infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
-            (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
-        )
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)

         infer_state.init_some_extra_state(self, input_ids)
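
For orientation, a toy contrast of the two allocation strategies. The sizes and dtype below are made up for illustration; only the buffer shape formula comes from the diff, and the real code allocates on device="cuda".

import torch

# Illustrative sizes; lightllm derives these from the model config.
batch_size, tp_k_head_num, tp_v_head_num, head_dim = 4, 8, 8, 128

# Before this commit: _decode eagerly ran torch.empty on every decode step,
# creating a fresh KV staging buffer each time.
kv_buffer = torch.empty(
    (batch_size, tp_k_head_num + tp_v_head_num, head_dim), dtype=torch.float16
)

# After: _decode only records mem_index; the staging buffer is requested on
# demand inside _pre_cache_kv via self.alloc_tensor (second file below),
# which can reuse memory across steps instead of re-allocating it.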

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 6 additions & 1 deletion
@@ -35,7 +35,12 @@ def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torc
                 infer_state.mem_start : infer_state.mem_end, :, :
             ]
         else:
-            cache_kv = infer_state.kv_buffer
+            dtype = infer_state.mem_manager.kv_buffer.dtype
+            cache_kv = self.alloc_tensor(
+                [infer_state.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_],
+                dtype=dtype,
+                device="cuda",
+            )
         return cache_kv

     def _get_qkv(
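
The diff shows only the call site of self.alloc_tensor; the commit does not show how the allocation is serviced. A minimal sketch of one plausible backing implementation, assuming a cache keyed by shape, dtype, and device so that repeated decode steps reuse one buffer instead of calling torch.empty each time (the class name TensorPoolMixin and the caching policy are illustrative, not lightllm's actual code):

import torch

class TensorPoolMixin:
    # Illustrative only: keeps one cached tensor per (shape, dtype, device)
    # key and hands the same storage back on every matching request.
    def __init__(self):
        self._pool = {}

    def alloc_tensor(self, shape, dtype, device="cuda"):
        key = (tuple(shape), dtype, device)
        buf = self._pool.get(key)
        if buf is None:
            # First request with this signature: allocate once, then reuse.
            buf = torch.empty(tuple(shape), dtype=dtype, device=device)
            self._pool[key] = buf
        return buf

Whatever the real reuse policy is, note the dtype change at the call site: the staging buffer now follows infer_state.mem_manager.kv_buffer.dtype, so it always matches the pooled KV cache rather than self.data_type.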
