
Commit 1388391

fix
1 parent 1ebdd7b commit 1388391

File tree

2 files changed: +4, -35 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 22 deletions
@@ -515,27 +515,9 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
         input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
         input_tensors = [input_embs]
 
-        # prefill cuda graph: due to special handling in the first few layers of qwen3 vl, cuda graph is currently not supported for them
-        from lightllm.utils.config_utils import is_qwen3_vl
-
-        if is_qwen3_vl():
-            no_graph_layer_num = 3
-        else:
-            no_graph_layer_num = 0
-
-        def no_graph_prefill_func(input_tensors, infer_state):
-            _input_embs = input_tensors[0]
-            for i in range(no_graph_layer_num):
-                layer = self.layers_infer[i]
-                layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
-                _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
-            return [_input_embs]
-
-        input_tensors = no_graph_prefill_func(input_tensors=input_tensors, infer_state=infer_state)
-
-        def graph_prefill_func(input_tensors, infer_state):
+        def prefill_func(input_tensors, infer_state):
             _input_embs = input_tensors[0]
-            for i in range(no_graph_layer_num, self.layers_num):
+            for i in range(self.layers_num):
                 layer = self.layers_infer[i]
                 layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
                 _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
@@ -549,7 +531,7 @@ def graph_prefill_func(input_tensors, infer_state):
             )
             if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
                 output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
-                    prefill_func=graph_prefill_func,
+                    prefill_func=prefill_func,
                     input_tensors=input_tensors,
                     infer_state=infer_state,
                 )
@@ -560,7 +542,7 @@ def graph_prefill_func(input_tensors, infer_state):
 
         else:
             g_cache_manager.cache_env_in()
-            output_tensors: List[torch.Tensor] = graph_prefill_func(input_tensors, infer_state)
+            output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state)
             g_cache_manager.cache_env_out()
 
         input_embs = output_tensors[0]
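
For quick orientation, the change collapses the two prefill paths (an eager prefix of "no-graph" layers plus a graphed remainder) into a single prefill_func that runs every layer and is either captured by the prefill CUDA graph or executed eagerly. Below is a minimal, self-contained sketch of that control flow; ToyLayer, ToyPrefillGraph, and prefill are illustrative stand-ins, not lightllm's actual classes or signatures.

# Self-contained sketch of the control flow after this commit: one prefill_func
# runs every layer, and it is either captured by the prefill CUDA-graph wrapper
# or called eagerly. ToyLayer and ToyPrefillGraph are hypothetical stand-ins.
from typing import Callable, List, Optional

import torch


class ToyLayer:
    def context_forward(self, embs: torch.Tensor, infer_state, weight) -> torch.Tensor:
        return embs + weight  # placeholder for the real attention/MLP work

    # tensor-parallel + sequence-parallel variant, selected by run_mode_index == 1
    tpsp_context_forward = context_forward


class ToyPrefillGraph:
    """Stands in for the prefill CUDA-graph manager: capture once, replay later."""

    def capture_prefill(self, prefill_func: Callable, input_tensors, infer_state):
        # A real implementation would record prefill_func into a CUDA graph here.
        return prefill_func(input_tensors, infer_state)


def prefill(layers: List[ToyLayer], weights, input_tensors, infer_state,
            run_mode_index: int = 0, prefill_graph: Optional[ToyPrefillGraph] = None):
    def prefill_func(tensors, state):
        _input_embs = tensors[0]
        for layer, weight in zip(layers, weights):  # every layer, no special-cased prefix
            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
            _input_embs = layer_method(_input_embs, state, weight)
        return [_input_embs]

    if prefill_graph is not None:
        return prefill_graph.capture_prefill(prefill_func, input_tensors, infer_state)
    return prefill_func(input_tensors, infer_state)  # eager fallback


if __name__ == "__main__":
    layers = [ToyLayer() for _ in range(3)]
    weights = [torch.ones(4) * i for i in range(3)]
    out = prefill(layers, weights, [torch.zeros(4)], infer_state=None,
                  prefill_graph=ToyPrefillGraph())
    print(out[0])  # tensor([3., 3., 3., 3.])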

lightllm/utils/config_utils.py

Lines changed: 0 additions & 13 deletions
@@ -132,16 +132,3 @@ def get_fixed_kv_len():
         return len(model_cfg["prompt_cache_token_ids"])
     else:
         return 0
-
-
-@lru_cache(maxsize=None)
-def is_qwen3_vl():
-    from lightllm.utils.llm_utils import get_llm_model_class
-    from lightllm.models import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
-
-    model_class = get_llm_model_class()
-
-    if model_class in [Qwen3VLTpPartModel, Qwen3VLMOETpPartModel]:
-        return True
-    else:
-        return False
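
The deleted helper relied on functools.lru_cache applied to a zero-argument function, which turns the model-class lookup into a compute-once check cached for the life of the process. A generic sketch of that memoization pattern is shown below; get_model_class, SpecialModel, and is_special_model are hypothetical names, not lightllm APIs.

# Generic sketch of the compute-once pattern the removed is_qwen3_vl() relied on:
# lru_cache on a zero-argument function caches its single result for the whole process.
from functools import lru_cache


class SpecialModel:  # stand-in for a concrete class such as Qwen3VLTpPartModel
    pass


def get_model_class():  # hypothetical loader; lightllm resolves this from the model config
    return SpecialModel


@lru_cache(maxsize=None)
def is_special_model() -> bool:
    # Evaluated once; later calls are served from the cache.
    return get_model_class() in (SpecialModel,)


if __name__ == "__main__":
    assert is_special_model() and is_special_model()
    print(is_special_model.cache_info())  # hits=1, misses=1 after the two calls above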
