Skip to content

Commit 1bbf90b

Browse files
committed
fix
1 parent 1388391 commit 1bbf90b

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features
2020
from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
2121
from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward
22+
from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
2223

2324

2425
class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer):
@@ -77,9 +78,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
7778
if self.tp_world_size_ > 1:
7879
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
7980
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
80-
apply_deepstack_features(
81+
self._apply_deepstack_features_wrapper_run(
8182
input_embeddings=input_embdings,
8283
infer_state=infer_state,
8384
layer_num=self.layer_num_,
8485
)
8586
return input_embdings
87+
88+
def _apply_deepstack_features_wrapper_run(
    self,
    input_embeddings: torch.Tensor,
    infer_state: InferStateInfo,
    layer_num: int,
):
    """Run ``apply_deepstack_features``, deferring it when a CUDA graph is being captured.

    When the current CUDA stream is not capturing, ``apply_deepstack_features``
    is called directly on ``input_embeddings``.

    When the stream is capturing, the call is not made here. Instead the
    current prefill capture graph is exited, a new graph object is created
    and entered, and a closure that performs the call is registered through
    ``infer_state.prefill_cuda_graph_add_cpu_runnning_func`` with
    ``after_graph`` set to the graph that was just exited.

    Args:
        input_embeddings: Tensor handed to ``apply_deepstack_features``.
        infer_state: Inference state; also provides the prefill CUDA-graph
            helper methods used on the capturing path.
        layer_num: Layer index forwarded to ``apply_deepstack_features``.

    Returns:
        None.
    """
    # Fast path: no graph capture in progress, so apply the features right away.
    if not torch.cuda.is_current_stream_capturing():
        apply_deepstack_features(
            input_embeddings=input_embeddings,
            infer_state=infer_state,
            layer_num=layer_num,
        )
        return

    # NOTE(review): if the incoming tensor were not already contiguous,
    # .contiguous() would return a copy and the deferred call would operate
    # on that copy rather than the caller's tensor -- confirm callers always
    # pass a contiguous tensor.
    input_embeddings = input_embeddings.contiguous()
    # presumably a view of the same storage that does not keep a reference
    # to the original tensor object -- verify against tensor_to_no_ref_tensor.
    no_ref_embeddings = tensor_to_no_ref_tensor(input_embeddings)

    # Close the graph that has been capturing up to this point.
    finished_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
    finished_graph.__exit__(None, None, None)

    # Create a fresh graph object and start capturing into it.
    infer_state.prefill_cuda_graph_create_graph_obj()
    next_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
    next_graph.__enter__()

    def apply_func(new_infer_state: InferStateInfo):
        # Uses the infer state supplied by the caller of this closure,
        # not the one visible while the graph was being captured.
        apply_deepstack_features(
            input_embeddings=no_ref_embeddings,
            infer_state=new_infer_state,
            layer_num=layer_num,
        )

    # Register the closure against the graph that was just closed.
    infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=apply_func, after_graph=finished_graph)

0 commit comments

Comments
 (0)