
Commit cc2860e

clean code
1 parent 05bb49f commit cc2860e

1 file changed: 73 additions, 70 deletions


lightllm/common/basemodel/basemodel.py

Lines changed: 73 additions & 70 deletions
@@ -1,6 +1,7 @@
 import os
 
 # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+import gc
 import copy
 import json
 import torch
@@ -391,6 +392,71 @@ def _decode(
 
         return model_output
 
+    def _build_prefill_model_input(
+        self, input_len: int, random_token: bool = False, include_special: bool = False
+    ) -> ModelInput:
+        dummy_input_ids = (
+            torch.randint(0, 10000, (input_len,), dtype=torch.int32, device="cuda")
+            if random_token
+            else torch.ones(input_len, dtype=torch.int32, device="cuda")
+        )
+        b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
+        mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
+        b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+        b_seq_len[:] = input_len
+        b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        total_token_num = input_len
+        b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+
+        special_kwargs = {}
+        if include_special:
+            special_kwargs.update(self._gen_special_model_input(total_token_num))
+
+        model_input = ModelInput(
+            batch_size=1,
+            total_token_num=total_token_num,
+            max_len_in_batch=input_len,
+            input_ids=dummy_input_ids,
+            mem_indexes=mem_indexes,
+            b_req_idx=b_req_idx,
+            b_seq_len=b_seq_len,
+            b_mtp_index=b_mtp_index,
+            is_prefill=True,
+            b_ready_cache_len=b_ready_cache_len,
+            multimodal_params=[],
+            **special_kwargs,
+        )
+        return model_input
+
+    def _build_padded_prefill_hold_model_input(self, prefill_input_len: int, batch_size: int) -> ModelInput:
+        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        b_req_idx = torch.tensor(
+            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+        )
+        mem_indexes = torch.tensor(
+            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+        )
+        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        total_token_num = prefill_input_len * batch_size
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+
+        model_input = ModelInput(
+            batch_size=batch_size,
+            total_token_num=total_token_num,
+            max_len_in_batch=prefill_input_len,
+            input_ids=dummy_input_ids,
+            mem_indexes=mem_indexes,
+            b_req_idx=b_req_idx,
+            b_mtp_index=b_mtp_index,
+            b_seq_len=b_seq_len,
+            b_ready_cache_len=b_ready_cache_len,
+            is_prefill=True,
+            multimodal_params=[],
+            **self._gen_special_model_input(total_token_num),
+        )
+        return model_input
+
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
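
Note: the only behavioral differences between the two former call sites folded into _build_prefill_model_input are the random_token and include_special flags. A minimal, CPU-only sketch of the token-id branch (illustrative only, not part of the commit; the real helper allocates on "cuda" and also wires up req_manager/mem_manager state):

import torch

def build_dummy_input_ids(input_len: int, random_token: bool = False) -> torch.Tensor:
    # random ids for autotune warmup, constant ones for the max-len OOM check
    if random_token:
        return torch.randint(0, 10000, (input_len,), dtype=torch.int32)
    return torch.ones(input_len, dtype=torch.int32)

print(build_dummy_input_ids(4))                     # tensor([1, 1, 1, 1], dtype=torch.int32)
print(build_dummy_input_ids(4, random_token=True))  # four random ids in [0, 10000)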
@@ -680,25 +746,8 @@ def _check_max_len_infer(self):
         # Simulate a prefill at the maximum length and watch for OOM
         try:
             logger.info("begin check max_len infer")
-            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
-            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
-            b_seq_len[:] = self.batch_max_tokens
-            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-            total_token_num = self.batch_max_tokens
-            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
-            model_input = ModelInput(
-                batch_size=1,
-                total_token_num=total_token_num,
-                max_len_in_batch=self.batch_max_tokens,
-                input_ids=dummy_input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
-                is_prefill=True,
-                b_ready_cache_len=b_ready_cache_len,
+            model_input = self._build_prefill_model_input(
+                self.batch_max_tokens, random_token=False, include_special=False
             )
             model_output = self.forward(
                 model_input,
@@ -752,40 +801,21 @@ def _autotune_warmup(self):
         for input_len in warmup_lengths:
             try:
                 logger.info(f"autotune warmup for length {input_len}")
-                dummy_input_ids = torch.randint(0, 10000, (input_len,), dtype=torch.int32, device="cuda")
-                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
-                b_seq_len[:] = input_len
-                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-                total_token_num = input_len
-                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
-                model_input = ModelInput(
-                    batch_size=1,
-                    total_token_num=total_token_num,
-                    max_len_in_batch=input_len,
-                    input_ids=dummy_input_ids,
-                    mem_indexes=mem_indexes,
-                    b_req_idx=b_req_idx,
-                    b_seq_len=b_seq_len,
-                    b_mtp_index=b_mtp_index,
-                    is_prefill=True,
-                    b_ready_cache_len=b_ready_cache_len,
-                    multimodal_params=[],
-                    **self._gen_special_model_input(total_token_num),
-                )
+                model_input = self._build_prefill_model_input(input_len, random_token=True, include_special=True)
                 model_output = self.forward(
                     model_input,
                 )
                 del model_output
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                gc.collect()
                 torch.cuda.empty_cache()
                 logger.info(f"autotune warmup for length {input_len} ok")
             except Exception as e:
                 logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                gc.collect()
                 torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
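
Note: besides switching to the shared helper, this hunk adds a gc.collect() before each torch.cuda.empty_cache(), so Python-side references to the warmup tensors are released before the CUDA caching allocator is asked to return blocks. A standalone sketch of that cleanup order (assumes only that PyTorch is installed; the is_available() guard is added here for portability and is not in the diff):

import gc
import torch

def release_transient_buffers() -> None:
    gc.collect()  # drop unreachable Python objects (and the tensor storage they hold) first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # then hand cached, unused CUDA blocks back to the driver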
@@ -803,39 +833,12 @@ def _init_padded_req(self):
         # prefill init padding req.
         prefill_input_len = 1
         batch_size = 1
-        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        b_req_idx = torch.tensor(
-            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-        )
-        mem_indexes = torch.tensor(
-            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-        )
-        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        total_token_num = prefill_input_len * batch_size
-        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        model_input = ModelInput(
-            batch_size=batch_size,
-            total_token_num=total_token_num,
-            max_len_in_batch=prefill_input_len,
-            input_ids=dummy_input_ids,
-            mem_indexes=mem_indexes,
-            b_req_idx=b_req_idx,
-            b_mtp_index=b_mtp_index,
-            b_seq_len=b_seq_len,
-            b_ready_cache_len=b_ready_cache_len,
-            is_prefill=True,
-            multimodal_params=[],
-            **self._gen_special_model_input(total_token_num),
+        model_input = self._build_padded_prefill_hold_model_input(
+            prefill_input_len=prefill_input_len, batch_size=batch_size
         )
 
         model_output: ModelOutput = self.forward(model_input)
         del model_input
-        del dummy_input_ids
-        del b_req_idx
-        del mem_indexes
-        del b_seq_len
-        del b_ready_cache_len
         del model_output
         torch.cuda.empty_cache()
         return
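
Note: _build_padded_prefill_hold_model_input fills b_req_idx and mem_indexes with the HOLD_REQUEST_ID / HOLD_TOKEN_MEMINDEX sentinels instead of allocating real slots, so the padding prefill in _init_padded_req never touches the request or KV-cache pools. A toy sketch of that sentinel-fill pattern (CPU tensors; the sentinel values below are made up for illustration and are not taken from the diff):

import torch

HOLD_REQUEST_ID = -1      # assumed placeholder value, for illustration only
HOLD_TOKEN_MEMINDEX = -1  # assumed placeholder value, for illustration only

batch_size = 1
b_req_idx = torch.full((batch_size,), HOLD_REQUEST_ID, dtype=torch.int32)
mem_indexes = torch.full((batch_size,), HOLD_TOKEN_MEMINDEX, dtype=torch.int32)
print(b_req_idx, mem_indexes)  # tensor([-1], dtype=torch.int32) tensor([-1], dtype=torch.int32)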
