Commit c86beba

fix
1 parent bc605a9 commit c86beba

2 files changed: +6 -6 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 0 deletions

@@ -334,6 +334,11 @@ def _decode(
             # So we no longer use the optimization of allocating contiguous mem, to keep the inference flow consistent
             infer_state.mem_is_contiguous = False
             infer_state.mem_index = mem_indexes
+            infer_state.kv_buffer = torch.empty(
+                (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
+                dtype=self.data_type,
+                device="cuda",
+            )
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)

         infer_state.init_some_extra_state(self, input_ids)
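The added allocation happens once per decode step and sizes the buffer so the K and V heads of the new token share a single tensor: the second dimension is tp_k_head_num_ + tp_v_head_num_. Below is a minimal sketch of that shape; the head counts and dtype are made-up stand-ins for the model's real self.tp_k_head_num_, self.tp_v_head_num_, self.head_dim_ and self.data_type, and the K-first/V-second split is an assumption for illustration, not something this diff states.

    import torch

    # Hypothetical per-GPU head counts and head size; the real values come from the model config.
    batch_size, tp_k_heads, tp_v_heads, head_dim = 4, 8, 8, 128

    # Mirrors the new allocation in _decode: one scratch buffer for the fused K+V heads,
    # created once per decode step rather than once per layer.
    kv_buffer = torch.empty((batch_size, tp_k_heads + tp_v_heads, head_dim), dtype=torch.float16)

    # Assumed layout: K fills the first tp_k_heads slots, V fills the rest.
    k_part = kv_buffer[:, :tp_k_heads, :]
    v_part = kv_buffer[:, tp_k_heads:, :]
    assert k_part.shape == (4, 8, 128) and v_part.shape == (4, 8, 128)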

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 1 addition & 6 deletions

@@ -35,12 +35,7 @@ def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torc
                 infer_state.mem_start : infer_state.mem_end, :, :
             ]
         else:
-            dtype = infer_state.mem_manager.kv_buffer.dtype
-            cache_kv = self.alloc_tensor(
-                [infer_state.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_],
-                dtype=dtype,
-                device="cuda",
-            )
+            cache_kv = infer_state.kv_buffer
         return cache_kv

     def _get_qkv(
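Taken together, the two hunks replace a per-layer alloc_tensor call with a buffer that _decode allocates once and every layer's _pre_cache_kv simply returns. The sketch below compresses that reuse pattern into standalone stand-ins; InferState and Layer here are illustrative, not the actual lightllm InferStateInfo or layer template classes.

    import torch

    class InferState:
        # Stand-in for InferStateInfo: the decode step allocates the scratch buffer up front.
        def __init__(self, batch_size, kv_head_num, head_dim, dtype=torch.float16, device="cpu"):
            self.kv_buffer = torch.empty((batch_size, kv_head_num, head_dim), dtype=dtype, device=device)

    class Layer:
        # Stand-in for the transformer layer template: after this commit, the
        # non-contiguous branch of _pre_cache_kv just returns the shared buffer.
        def _pre_cache_kv(self, infer_state):
            return infer_state.kv_buffer

    state = InferState(batch_size=4, kv_head_num=16, head_dim=128)
    layers = [Layer() for _ in range(2)]
    buffers = [layer._pre_cache_kv(state) for layer in layers]
    # Every layer sees the same storage, so no per-layer allocation happens.
    assert all(b.data_ptr() == state.kv_buffer.data_ptr() for b in buffers)

One caveat this pattern relies on (an assumption here, not stated in the diff): because all layers now alias the same storage, each layer must fully overwrite the buffer before reading from it.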
