
Commit 9352fe8

qwen3_vl support prefill cuda graph feature
1 parent 60e5d01 commit 9352fe8

File tree

4 files changed, +40 −8 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 23 additions & 4 deletions
@@ -515,9 +515,27 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
         input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
         input_tensors = [input_embs]

-        def prefill_func(input_tensors, infer_state):
+        # Because of the special handling in the first few layers of qwen3 vl, prefill cuda graph cannot capture them yet.
+        from lightllm.utils.config_utils import is_qwen3_vl
+
+        if is_qwen3_vl():
+            no_graph_layer_num = 3
+        else:
+            no_graph_layer_num = 0
+
+        def no_graph_prefill_func(input_tensors, infer_state):
+            _input_embs = input_tensors[0]
+            for i in range(no_graph_layer_num):
+                layer = self.layers_infer[i]
+                layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+                _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
+            return [_input_embs]
+
+        input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state)
+
+        def graph_prefill_func(input_tensors, infer_state):
             _input_embs = input_tensors[0]
-            for i in range(self.layers_num):
+            for i in range(no_graph_layer_num, self.layers_num):
                 layer = self.layers_infer[i]
                 layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
                 _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
@@ -531,7 +549,7 @@ def prefill_func(input_tensors, infer_state):
            )
            if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
                output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
-                    prefill_func=prefill_func,
+                    prefill_func=graph_prefill_func,
                    input_tensors=input_tensors,
                    infer_state=infer_state,
                )
@@ -542,7 +560,8 @@ def prefill_func(input_tensors, infer_state):

        else:
            g_cache_manager.cache_env_in()
-            output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state)
+            input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state)
+            output_tensors: List[torch.Tensor] = graph_prefill_func(input_tensors, infer_state)
            g_cache_manager.cache_env_out()

        input_embs = output_tensors[0]
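The basemodel.py change splits prefill into an eager prefix (the first no_graph_layer_num layers, run outside any capture) and a suffix that is captured and replayed as a CUDA graph. The snippet below is a minimal standalone sketch of that pattern, not lightllm's actual classes: the PartialGraphPrefill name, the layer callables, and the fixed-shape assumption are all illustrative.

import torch

class PartialGraphPrefill:
    def __init__(self, layers, no_graph_layer_num: int):
        self.layers = layers                      # list of callables, e.g. nn.Modules on CUDA
        self.no_graph_layer_num = no_graph_layer_num
        self.graph = None
        self.static_in = None
        self.static_out = None

    def _graphed_layers(self, x):
        # The layers that are safe to record into a CUDA graph.
        for layer in self.layers[self.no_graph_layer_num:]:
            x = layer(x)
        return x

    def forward(self, x):
        # Eager prefix: layers that need special, non-capturable handling run normally.
        for layer in self.layers[:self.no_graph_layer_num]:
            x = layer(x)

        if self.graph is None:
            # One-time capture: warm up on a side stream, then record the remaining
            # layers into a CUDA graph against static input/output buffers.
            self.static_in = x.clone()
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                self._graphed_layers(self.static_in)
            torch.cuda.current_stream().wait_stream(side)

            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_out = self._graphed_layers(self.static_in)

        # Replay: copy the fresh activations into the captured input buffer and replay.
        self.static_in.copy_(x)
        self.graph.replay()
        return self.static_out

# Example usage (assumes a CUDA device and a fixed input shape across calls, as CUDA graphs require):
# runner = PartialGraphPrefill([torch.nn.Linear(64, 64).cuda() for _ in range(6)], no_graph_layer_num=3)
# out = runner.forward(torch.randn(8, 64, device="cuda"))

As in the diff, the layers that behave differently per request (here, qwen3 vl's first layers) stay outside the captured region, so the graph only replays shape-stable work.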

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 3 additions & 3 deletions
@@ -70,7 +70,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)

-        o = self.__context_attention_wrapper_run(
+        o = self._context_attention_wrapper_run(
            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
        )

@@ -116,7 +116,7 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)

-        o = self.__context_attention_wrapper_run(
+        o = self._context_attention_wrapper_run(
            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
        )

@@ -148,7 +148,7 @@ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferSta
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings

-    def __context_attention_wrapper_run(
+    def _context_attention_wrapper_run(
        self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight
    ) -> torch.Tensor:
        if torch.cuda.is_current_stream_capturing():
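The rename from __context_attention_wrapper_run to _context_attention_wrapper_run is what lets the qwen3_vl layer below call the wrapper at all: Python mangles double-underscore method names per defining class, so a subclass cannot reach a base class's __-prefixed method under its own name. A generic illustration follows; the class and method names here are made up, not lightllm's.

class Base:
    def __run(self):            # stored as _Base__run via name mangling
        return "base impl"

    def _run(self):             # single underscore: reachable from subclasses by convention
        return "base impl"


class Child(Base):
    def call_wrapper(self):
        # self.__run() here would be rewritten to self._Child__run and raise AttributeError;
        # the single-underscore name resolves to the base class method as usual.
        return self._run()


print(Child().call_wrapper())   # prints "base impl"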

lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
-        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+        o = self._context_attention_wrapper_run(q, cache_kv, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
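qwen3 vl's context_forward now routes through _context_attention_wrapper_run instead of calling _context_attention_kernel directly. Judging from the if torch.cuda.is_current_stream_capturing(): context line visible in the template diff above, the wrapper appears to branch on whether a CUDA graph capture is in progress. Below is a hedged sketch of that kind of dispatch only; both kernel callables are hypothetical stand-ins, not lightllm's real implementations.

import torch

def attention_wrapper_run(q, cache_kv, eager_kernel, graph_safe_kernel):
    # Hypothetical dispatch: while a CUDA graph is being captured, route to a path
    # that avoids dynamic shapes and host-side synchronization; otherwise run the
    # ordinary eager attention kernel.
    if torch.cuda.is_current_stream_capturing():
        return graph_safe_kernel(q, cache_kv)
    return eager_kernel(q, cache_kv)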

lightllm/utils/config_utils.py

Lines changed: 13 additions & 0 deletions
@@ -132,3 +132,16 @@ def get_fixed_kv_len():
        return len(model_cfg["prompt_cache_token_ids"])
    else:
        return 0
+
+
+@lru_cache(maxsize=None)
+def is_qwen3_vl():
+    from lightllm.utils.llm_utils import get_llm_model_class
+    from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
+
+    model_class = get_llm_model_class()
+
+    if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]:
+        return True
+    else:
+        return False
