@@ -515,27 +515,9 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
515515 input_embs = pre_method (cuda_input_ids , infer_state , self .pre_post_weight )
516516 input_tensors = [input_embs ]
517517
518- # prefill cuda graph 在 qwen3 vl 上的前几层由于特殊的处理,导致目前无法支持cuda graph
519- from lightllm .utils .config_utils import is_qwen3_vl
520-
521- if is_qwen3_vl ():
522- no_graph_layer_num = 3
523- else :
524- no_graph_layer_num = 0
525-
526- def no_graph_prefill_func (input_tensors , infer_state ):
527- _input_embs = input_tensors [0 ]
528- for i in range (no_graph_layer_num ):
529- layer = self .layers_infer [i ]
530- layer_method = (layer .context_forward , layer .tpsp_context_forward )[run_mode_index ]
531- _input_embs = layer_method (_input_embs , infer_state , self .trans_layers_weight [i ])
532- return [_input_embs ]
533-
534- input_tensors = no_graph_prefill_func (input_tensors = input_tensors , infer_state = infer_state )
535-
536- def graph_prefill_func (input_tensors , infer_state ):
518+ def prefill_func (input_tensors , infer_state ):
537519 _input_embs = input_tensors [0 ]
538- for i in range (no_graph_layer_num , self .layers_num ):
520+ for i in range (self .layers_num ):
539521 layer = self .layers_infer [i ]
540522 layer_method = (layer .context_forward , layer .tpsp_context_forward )[run_mode_index ]
541523 _input_embs = layer_method (_input_embs , infer_state , self .trans_layers_weight [i ])
@@ -549,7 +531,7 @@ def graph_prefill_func(input_tensors, infer_state):
549531 )
550532 if self .prefill_graph .need_capture (handle_token_num = finded_handle_token_num ):
551533 output_tensors : List [torch .Tensor ] = self .prefill_graph .capture_prefill (
552- prefill_func = graph_prefill_func ,
534+ prefill_func = prefill_func ,
553535 input_tensors = input_tensors ,
554536 infer_state = infer_state ,
555537 )
@@ -560,7 +542,7 @@ def graph_prefill_func(input_tensors, infer_state):
560542
561543 else :
562544 g_cache_manager .cache_env_in ()
563- output_tensors : List [torch .Tensor ] = graph_prefill_func (input_tensors , infer_state )
545+ output_tensors : List [torch .Tensor ] = prefill_func (input_tensors , infer_state )
564546 g_cache_manager .cache_env_out ()
565547
566548 input_embs = output_tensors [0 ]
0 commit comments