
Commit abe25a3

fix
1 parent ad85b56 commit abe25a3

File tree

3 files changed: +1 −19 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 10 deletions
@@ -302,7 +302,6 @@ def _prefill(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
         infer_state.kv_buffer = torch.empty(
             (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
@@ -351,9 +350,6 @@ def _decode(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        # When using the cuda graph feature, every inference pass must follow the same flow,
-        # so the optimization from allocating contiguous mem is no longer used, keeping the inference flow consistent.
-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
         infer_state.kv_buffer = torch.empty(
             (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
@@ -398,9 +394,6 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When using the cuda graph feature, every inference pass must follow the same flow,
-            # so the optimization from allocating contiguous mem is no longer used, keeping the inference flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
             infer_state.kv_buffer = torch.empty(
                 (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
@@ -475,9 +468,6 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When using the cuda graph feature, every inference pass must follow the same flow,
-            # so the optimization from allocating contiguous mem is no longer used, keeping the inference flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
             infer_state.kv_buffer = torch.empty(
                 (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
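Note: the deleted comments explain the motivation — CUDA graph capture replays a fixed kernel sequence, so a data-dependent branch like the old `if infer_state.mem_is_contiguous:` would bake in whichever path happened to be captured. Keeping a single unconditional path is what makes the step capturable. A minimal, self-contained sketch of that constraint (illustrative only, not lightllm code; `decode_step` is a stand-in for the real forward):

import torch

if torch.cuda.is_available():
    static_in = torch.zeros(4, 64, device="cuda")
    static_out = torch.zeros(4, 64, device="cuda")

    def decode_step(x):
        # stand-in for the real decode forward; must run the same ops every call
        return x * 2 + 1

    # warmup on a side stream, as recommended before graph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(decode_step(static_in))
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(decode_step(static_in))

    static_in.fill_(3.0)
    g.replay()  # re-runs the captured kernels against the updated static_in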

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 3 deletions
@@ -25,10 +25,7 @@ def __init__(self):
         self.mem_manager: MemoryManager = None
         self.req_manager: ReqManager = None

-        self.mem_is_contiguous = None
         self.mem_index = None
-        self.mem_start = None
-        self.mem_end = None
         self.kv_buffer = None

         self.is_token_healing = False
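After this change the inference state keeps only an index-based view of KV memory. A rough sketch of the remaining memory-related fields (illustrative class, not the real InferStateInfo):

class InferStateSketch:
    def __init__(self):
        self.mem_manager = None  # owner of the pooled KV cache
        self.req_manager = None  # per-request bookkeeping
        self.mem_index = None    # slot indices assigned to this step's tokens
        self.kv_buffer = None    # per-step staging buffer for K/V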

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 1 addition & 6 deletions
@@ -31,12 +31,7 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
         raise Exception("need to impl")

     def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
-        if infer_state.mem_is_contiguous:
-            cache_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][
-                infer_state.mem_start : infer_state.mem_end, :, :
-            ]
-        else:
-            cache_kv = infer_state.kv_buffer
+        cache_kv = infer_state.kv_buffer
         return cache_kv

     def _get_qkv(
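With `_pre_cache_kv` always returning the per-step `kv_buffer`, the write into the pooled cache becomes an index-based scatter through `mem_index`, which behaves the same whether or not the slots happen to be contiguous. A hedged sketch of that pattern (illustrative shapes and names, not lightllm's actual kernels):

import torch

head_num, head_dim, pool_size, batch = 8, 64, 1024, 4
layer_kv_pool = torch.zeros((pool_size, head_num, head_dim))  # one layer's persistent KV pool
mem_index = torch.tensor([3, 17, 512, 900])                   # slots handed out by a memory manager
kv_buffer = torch.randn((batch, head_num, head_dim))          # filled by the QKV projection each step

layer_kv_pool.index_copy_(0, mem_index, kv_buffer)            # scatter write via mem_index
assert torch.equal(layer_kv_pool[mem_index], kv_buffer)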
