
Commit 7729439

Code clean up (#860)
1 parent 8b3a55a · commit 7729439

6 files changed: +19 −40 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 22 deletions
@@ -302,12 +302,10 @@ def _prefill(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
+        infer_state.kv_buffer_shapedtype = (
             (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
+            self.data_type,
         )
         infer_state.dist_group = dist_group_manager.get_default_group()

@@ -351,14 +349,10 @@ def _decode(
         infer_state.mem_manager = self.mem_manager
         infer_state.req_manager = self.req_manager

-        # When using the cuda graph feature, every inference run must follow the same flow,
-        # so the contiguous-mem allocation optimization is dropped to keep that flow consistent.
-        infer_state.mem_is_contiguous = False
         infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
+        infer_state.kv_buffer_shapedtype = (
             (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
+            self.data_type,
         )
         infer_state.dist_group = dist_group_manager.get_default_group()
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)

@@ -398,14 +392,10 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When using the cuda graph feature, every inference run must follow the same flow,
-            # so the contiguous-mem allocation optimization is dropped to keep that flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer = torch.empty(
+            infer_state.kv_buffer_shapedtype = (
                 (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                dtype=self.data_type,
-                device="cuda",
+                self.data_type,
             )
             infer_state.dist_group = dist_group_manager.get_group(batch_index)
             copy_kv_index_to_req(

@@ -475,14 +465,10 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
             infer_state.mem_manager = self.mem_manager
             infer_state.req_manager = self.req_manager

-            # When using the cuda graph feature, every inference run must follow the same flow,
-            # so the contiguous-mem allocation optimization is dropped to keep that flow consistent.
-            infer_state.mem_is_contiguous = False
             infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer = torch.empty(
+            infer_state.kv_buffer_shapedtype = (
                 (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                dtype=self.data_type,
-                device="cuda",
+                self.data_type,
             )
             infer_state.dist_group = dist_group_manager.get_group(batch_index)
             init_req_to_token_indexes(
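Taken together, the four hunks above stop allocating infer_state.kv_buffer eagerly with torch.empty and instead record only its shape and dtype in kv_buffer_shapedtype, leaving the actual allocation to the layer code. Below is a minimal sketch of that pattern, not the LightLLM implementation; the InferState class and prepare_prefill_state helper are invented names for illustration.

# Sketch only: record the KV buffer's (shape, dtype) instead of allocating it.
import torch


class InferState:
    def __init__(self):
        self.kv_buffer_shapedtype = None  # replaces the eagerly allocated kv_buffer


def prepare_prefill_state(input_ids, tp_k_head_num, tp_v_head_num, head_dim, data_type):
    state = InferState()
    # Only shape and dtype are stored; no torch.empty(..., device="cuda") here.
    state.kv_buffer_shapedtype = (
        (input_ids.shape[0], tp_k_head_num + tp_v_head_num, head_dim),
        data_type,
    )
    return state


state = prepare_prefill_state(torch.zeros(4, dtype=torch.long), 8, 8, 128, torch.float16)
shape, dtype = state.kv_buffer_shapedtype
print(shape, dtype)  # (4, 16, 128) torch.float16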

lightllm/common/basemodel/infer_struct.py

Lines changed: 1 addition & 4 deletions
@@ -25,11 +25,8 @@ def __init__(self):
         self.mem_manager: MemoryManager = None
         self.req_manager: ReqManager = None

-        self.mem_is_contiguous = None
         self.mem_index = None
-        self.mem_start = None
-        self.mem_end = None
-        self.kv_buffer = None
+        self.kv_buffer_shapedtype = None

         self.is_token_healing = False
         self.return_all_prompt_logics = False

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 9 additions & 9 deletions
@@ -31,12 +31,13 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
         raise Exception("need to impl")

     def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
-        if infer_state.mem_is_contiguous:
-            cache_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][
-                infer_state.mem_start : infer_state.mem_end, :, :
-            ]
-        else:
-            cache_kv = infer_state.kv_buffer
+        cache_kv = self.alloc_tensor(
+            shape=infer_state.kv_buffer_shapedtype[0],
+            dtype=infer_state.kv_buffer_shapedtype[1],
+            device="cuda",
+            is_graph_out=False,
+            microbatch_index=infer_state.microbatch_index,
+        )
         return cache_kv

     def _get_qkv(

@@ -51,9 +52,8 @@ def _tpsp_get_qkv(

     def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight):
         mem_manager = infer_state.mem_manager
-        if not infer_state.mem_is_contiguous:
-            self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager)
-            return
+        self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager)
+        return

     def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
         destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
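With kv_buffer_shapedtype in place, _pre_cache_kv always allocates a fresh buffer through self.alloc_tensor, and _post_cache_kv always scatters it into the per-layer KV cache at infer_state.mem_index. The toy sketch below mirrors that allocate-then-copy flow; torch.empty stands in for LightLLM's pooled alloc_tensor and Tensor.index_copy_ stands in for the destindex_copy_kv Triton kernel, so treat it as an illustration rather than the library's API.

# Toy sketch of the new allocate-then-copy KV path (illustration only).
import torch


def pre_cache_kv(kv_buffer_shapedtype):
    shape, dtype = kv_buffer_shapedtype
    return torch.empty(shape, dtype=dtype)  # always a fresh buffer, no contiguous-mem branch


def post_cache_kv(cache_kv, mem_index, layer_kv_cache):
    # Scatter the freshly computed K/V rows into the per-layer cache at the
    # token slots reserved in mem_index.
    layer_kv_cache.index_copy_(0, mem_index, cache_kv)


layer_kv_cache = torch.zeros(16, 2, 4)   # [max_tokens, kv_heads, head_dim]
mem_index = torch.tensor([3, 7, 11])     # slots reserved for this step
cache_kv = pre_cache_kv(((3, 2, 4), torch.float32))
cache_kv.normal_()
post_cache_kv(cache_kv, mem_index, layer_kv_cache)
print(torch.equal(layer_kv_cache[mem_index], cache_kv))  # True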

lightllm/models/deepseek2/infer_struct.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
         if not self.is_prefill:
             self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
-            self.total_token_num_tensor = torch.sum(self.b_seq_len)

         if self.is_prefill:
             self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
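For context on the kv_starts line retained in this hunk: it holds each request's start offset in the packed KV data, with the total token count appended as the final element. A small worked example with made-up lengths:

# Illustrative values only; b_start_loc is derived here the same way as
# b_kv_start_loc in the prefill branch above.
import torch

b_seq_len = torch.tensor([3, 5, 2])
b_start_loc = torch.cumsum(b_seq_len, dim=0) - b_seq_len   # tensor([0, 3, 8])
kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)
print(kv_starts)  # tensor([ 0,  3,  8, 10])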

test/kernel/tuning/deepseekv2_gqa_decode_tuning.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def test_decode_attentions(
     ).cuda()
     infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda()
     infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda()
-    infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len)

     input_tuples = []
     for _ in range(test_count):

test/kernel/tuning/llama_gqa_decode_vsm_tuning.py

Lines changed: 1 addition & 3 deletions
@@ -51,8 +51,6 @@ def test_decode_attentions(
     ).cuda()
     state.b_req_idx = torch.arange(0, state.batch_size, step=1, dtype=torch.int32).cuda()
     state.b_seq_len = torch.full((state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda()
-    total_token_num_tensor = torch.tensor([state.batch_size * test_seq_len], dtype=torch.int32, device="cuda")
-    state.total_token_num = total_token_num_tensor

     args = []
     q_head_dim = q_shape[2]

@@ -63,7 +61,7 @@ def test_decode_attentions(
     state.q_head_dim = q_head_dim
     state.kv_head_num = kv_head_num
     state.softmax_scale = 1 / (q_head_dim ** 0.5)
-    state.total_token_num = total_token_num_tensor
+    state.total_token_num = state.batch_size * test_seq_len

     infer_state = state
     for _ in range(test_count):
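The tuning script now sets state.total_token_num to a plain Python int rather than routing it through a one-element int32 CUDA tensor. A trivial sketch of the difference; the batch_size and test_seq_len values are made up:

# Sketch only.
batch_size, test_seq_len = 8, 512

# Removed form:
#   total_token_num = torch.tensor([batch_size * test_seq_len], dtype=torch.int32, device="cuda")
# New form: a host-side integer, read without touching the device.
total_token_num = batch_size * test_seq_len
print(total_token_num)  # 4096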
