Merged
Changes from 11 commits
192 changes: 136 additions & 56 deletions lightllm/common/basemodel/basemodel.py
@@ -6,7 +6,7 @@
import json
import torch
import torch.nn.functional as F
from typing import final
from typing import final, List
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -19,6 +19,7 @@
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
@@ -89,6 +90,7 @@ def __init__(self, kvargs):
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
self.prefill_graph: PrefillCudaGraph = None

self._init_config()
self._verify_must()
@@ -115,6 +117,7 @@ def __init__(self, kvargs):
# The wait must happen before cuda graph init to avoid erroneous graph capture.
self._wait_other_modules_ready()
self._init_cudagraph()
self._init_prefill_cuda_graph()
self._check_max_len_infer()
torch.cuda.empty_cache()
set_model_init_status(True)
@@ -240,6 +243,18 @@ def _init_cudagraph(self):
else:
self.graph.warmup(self)

def _init_prefill_cuda_graph(self):
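# Prefill graphs are only built when --enable_prefill_cudagraph is set; the
# decode-side graph is handed in so the two can share state (an assumption
# based on the constructor signature).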
self.prefill_graph = (
None
if not get_env_start_args().enable_prefill_cudagraph
else PrefillCudaGraph(decode_cuda_graph=self.graph)
)
if self.prefill_graph is not None:
if get_env_start_args().enable_prefill_microbatch_overlap:
self.prefill_graph.warmup_overlap(self)
else:
self.prefill_graph.warmup(self)

def _init_custom(self):
pass

@@ -332,6 +347,50 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s

return new_model_input

def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
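# Pad the prefill batch up to new_handle_token_num by appending one dummy
# request that owns all of the padding tokens, mirroring
# _create_padded_decode_model_input above.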
if model_input.total_token_num - model_input.prefix_total_token_num == new_handle_token_num:
return model_input

assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num

padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
new_model_input = copy.copy(model_input)
new_model_input.batch_size = model_input.batch_size + 1
new_model_input.total_token_num += padded_token_num
new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch)
new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len)
new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len)
new_model_input.max_cache_len = max(0, model_input.max_cache_len)
new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_token_num), mode="constant", value=1)
new_model_input.mem_indexes = F.pad(
new_model_input.mem_indexes,
(0, padded_token_num),
mode="constant",
value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
)
new_model_input.b_req_idx = F.pad(
new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
)
new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
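# An exclusive prefix sum of the per-request query lengths yields each
# request's start offset in the flattened token dimension.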
b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
# Build a new list here; appending in place could change a list still referenced by outside code and cause errors.
new_model_input.b_prefill_has_output_cpu = [e for e in new_model_input.b_prefill_has_output_cpu] + [False]
new_model_input.prefix_total_token_num = model_input.prefix_total_token_num

# TODO: check whether the multimodal parameters need padding as well.

# Special-case padding for special variables of specific models and modes.
if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None:
new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch(
input=new_model_input.deepseekv3_mtp_draft_input_hiddens,
new_batch_size=padded_token_num,
)

return new_model_input

def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int):
padded_batch_size = model_output.logits.shape[0]
if padded_batch_size == origin_batch_size:
@@ -346,10 +405,34 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba

return new_model_output

def _create_unpad_prefill_model_output(self, model_output: ModelOutput, origin_handle_token_num: int):
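# Drop the logits (and MTP hidden states) that were produced for padding
# tokens, so callers only ever see outputs for real tokens.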
handle_token_num = model_output.logits.shape[0]
if handle_token_num == origin_handle_token_num:
return model_output

new_model_output = copy.copy(model_output)
new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]

# Special-case unpadding for special variables of specific models and modes.
if new_model_output.deepseekv3_mtp_main_output_hiddens is not None:
_hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens
new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num]

return new_model_output

def _prefill(
self,
model_input: ModelInput,
):
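# handle_token_num counts only the tokens computed in this step, i.e. it
# excludes tokens already covered by the prefix cache.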
handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
handle_token_num=handle_token_num
)
model_input = self._create_padded_prefill_model_input(
model_input=model_input, new_handle_token_num=finded_handle_token_num
)

infer_state = self._create_inferstate(model_input)
init_req_to_token_indexes(
req_to_token_indexs=self.req_manager.req_to_token_indexs,
@@ -365,6 +448,7 @@ def _prefill(

infer_state.init_some_extra_state(self, model_input.input_ids)
model_output = self._context_forward(model_input.input_ids, infer_state)
model_output = self._create_unpad_prefill_model_output(model_output, origin_handle_token_num=handle_token_num)
model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
return model_output

@@ -419,22 +503,45 @@ def _decode(
@final
def _context_forward(self, input_ids, infer_state: InferStateInfo):
run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
g_cache_manager.cache_env_in()
cuda_input_ids = input_ids

pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
input_tensors = [input_embs]

for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
def prefill_func(input_tensors, infer_state):
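# The transformer layer loop is wrapped in a function so the same body
# can be captured into, or replayed from, a prefill CUDA graph below.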
_input_embs = input_tensors[0]
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
_input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
return [_input_embs]

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
handle_token_num = input_ids.shape[0]

g_cache_manager.cache_env_out()
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
handle_token_num=handle_token_num
)
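# The first run at this padded size captures the graph; later runs replay
# it with the new input tensors (assuming replay copies inputs in, as its
# signature suggests).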
if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
prefill_func=prefill_func,
input_tensors=input_tensors,
infer_state=infer_state,
)
else:
output_tensors: List[torch.Tensor] = self.prefill_graph.replay(
input_tensors=input_tensors, infer_state=infer_state
)
Comment on lines +528 to +541 (Contributor, severity: medium)

There's redundant logic here for checking if a CUDA graph can be used. The calling function _prefill already performs these checks (can_run, find_closest_graph_handle_token_num) and pads the input accordingly. This block repeats the same checks on the already-padded input.

This duplicated logic can be simplified. The decision to use a graph and the specific graph to use should be determined once in _prefill and then passed to _context_forward, for example via the infer_state object. This would make the code cleaner and easier to maintain.

A potential refactoring could look like this:

In _prefill:

# ...
if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
    finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
        handle_token_num=handle_token_num
    )
    model_input = self._create_padded_prefill_model_input(
        model_input=model_input, new_handle_token_num=finded_handle_token_num
    )
    infer_state = self._create_inferstate(model_input)
    infer_state.use_prefill_graph = True
    infer_state.prefill_graph_handle_token_num = finded_handle_token_num
# ...

In _context_forward:

# ...
if getattr(infer_state, 'use_prefill_graph', False):
    handle_token_num = infer_state.prefill_graph_handle_token_num
    if self.prefill_graph.need_capture(handle_token_num=handle_token_num):
        # capture logic
    else:
        # replay logic
else:
    # non-graph logic
# ...


else:
g_cache_manager.cache_env_in()
output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state)
g_cache_manager.cache_env_out()

input_embs = output_tensors[0]
post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
model_output = ModelOutput(logits=predict_logits)

# Extra outputs for specific models and modes.
@@ -449,40 +556,30 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
@final
def _token_forward(self, input_ids, infer_state: InferStateInfo):
run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
g_cache_manager.cache_env_in(
is_cuda_graph=infer_state.is_cuda_graph,
cur_batch_size=infer_state.batch_size,
cuda_graph_max_batch_size=self.graph_max_batch_size,
)
cuda_input_ids = input_ids
pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
for i in range(self.layers_num):
layer = self.layers_infer[i]
layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)

if self.is_deepseekv3_mtp_mode:
graph_out_hiddens = g_cache_manager.alloc_tensor(
input_embs.shape,
data_type=input_embs.dtype,
is_graph_out=True,
microbatch_index=infer_state.microbatch_index,
graph_out_key=520,
)
graph_out_hiddens.copy_(input_embs)

g_cache_manager.cache_env_out()
graph_out_hiddens = input_embs.contiguous()

model_output = ModelOutput(logits=predict_logits)
model_output = ModelOutput(logits=predict_logits.contiguous())

# Extra outputs for specific models and modes.
if self.is_deepseekv3_mtp_mode:
model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens

# In cuda graph mode, convert outputs to no-ref tensors to improve mem pool reuse and reduce GPU memory usage.
if infer_state.is_cuda_graph:
model_output.to_no_ref_tensor()

return model_output

@torch.no_grad()
@@ -642,24 +739,19 @@ def _overlap_tpsp_context_forward(
)
g_cache_manager.cache_env_out()

model_output = ModelOutput(logits=predict_logits)
model_output1 = ModelOutput(logits=predict_logits1)
model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_deepseekv3_mtp_mode:
model_output.deepseekv3_mtp_main_output_hiddens = input_embs
model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1
model_output.deepseekv3_mtp_main_output_hiddens = input_embs.contiguous()
model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1.contiguous()

return model_output, model_output1

@final
def _overlap_tpsp_token_forward(
self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
):
g_cache_manager.cache_env_in(
is_cuda_graph=infer_state.is_cuda_graph,
cur_batch_size=infer_state.batch_size,
cuda_graph_max_batch_size=self.graph_max_batch_size,
)
input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
)
@@ -674,32 +766,20 @@
)

if self.is_deepseekv3_mtp_mode:
graph_out_hiddens = g_cache_manager.alloc_tensor(
input_embs.shape,
data_type=input_embs.dtype,
is_graph_out=True,
microbatch_index=0,
graph_out_key=520,
)
graph_out_hiddens.copy_(input_embs)
graph_out_hiddens1 = g_cache_manager.alloc_tensor(
input_embs1.shape,
data_type=input_embs1.dtype,
is_graph_out=True,
microbatch_index=1,
graph_out_key=520,
)
graph_out_hiddens1.copy_(input_embs1)
graph_out_hiddens = input_embs.contiguous()
graph_out_hiddens1 = input_embs1.contiguous()

g_cache_manager.cache_env_out()

model_output = ModelOutput(logits=predict_logits)
model_output1 = ModelOutput(logits=predict_logits1)
model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_deepseekv3_mtp_mode:
model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens
model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1

if infer_state.is_cuda_graph:
model_output.to_no_ref_tensor()
model_output1.to_no_ref_tensor()

return model_output, model_output1

@final
6 changes: 6 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -3,6 +3,7 @@
from typing import Optional
from typing import List
from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor


@dataclass
@@ -88,3 +89,8 @@ class ModelOutput:
# Output the last layer's hidden states, used as the
# deepseekv3_mtp_draft_input_hiddens input of the draft model.
deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None

def to_no_ref_tensor(self):
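# Drop strong references to graph-owned output buffers so the memory pool
# can reuse them; see the is_cuda_graph branches in basemodel.py.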
self.logits = tensor_to_no_ref_tensor(self.logits)
if self.deepseekv3_mtp_main_output_hiddens is not None:
self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens)