|
19 | 19 | from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features |
20 | 20 | from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer |
21 | 21 | from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward |
| 22 | +from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor |
22 | 23 |
|
23 | 24 |
|
24 | 25 | class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): |
@@ -77,9 +78,42 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la |
77 | 78 | if self.tp_world_size_ > 1: |
78 | 79 | all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) |
79 | 80 | input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) |
80 | | - apply_deepstack_features( |
| 81 | + self._apply_deepstack_features_wrapper_run( |
81 | 82 | input_embeddings=input_embdings, |
82 | 83 | infer_state=infer_state, |
83 | 84 | layer_num=self.layer_num_, |
84 | 85 | ) |
85 | 86 | return input_embdings |
| 87 | + |
def _apply_deepstack_features_wrapper_run(
    self,
    input_embeddings: torch.Tensor,
    infer_state: InferStateInfo,
    layer_num: int,
):
    """Apply deepstack features now, or defer them when a CUDA graph is being captured.

    When the current CUDA stream is not capturing, ``apply_deepstack_features``
    is called directly on ``input_embeddings``.

    When the stream is capturing, the graph currently being captured is
    exited, a new graph object is created on ``infer_state`` and entered, and
    a callback that calls ``apply_deepstack_features`` is registered on
    ``infer_state`` with ``after_graph`` set to the graph that was just exited.

    Args:
        input_embeddings: Tensor handed to ``apply_deepstack_features``.
        infer_state: Inference state; also provides the prefill CUDA graph
            bookkeeping methods used on the capture path.
        layer_num: Layer index forwarded to ``apply_deepstack_features``.

    Returns:
        None.
    """
    # Plain (non-captured) execution: apply immediately and finish.
    if not torch.cuda.is_current_stream_capturing():
        apply_deepstack_features(
            input_embeddings=input_embeddings,
            infer_state=infer_state,
            layer_num=layer_num,
        )
        return

    # NOTE(review): if the input is ever non-contiguous, .contiguous() returns
    # a copy and the deferred update would presumably not reach the caller's
    # tensor -- confirm callers always pass a contiguous tensor.
    input_embeddings = input_embeddings.contiguous()
    detached_embeddings = tensor_to_no_ref_tensor(input_embeddings)

    # Close the graph that has been capturing up to this point.
    finished_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
    finished_graph.__exit__(None, None, None)

    # Open a fresh graph object so that subsequent work is captured into it.
    infer_state.prefill_cuda_graph_create_graph_obj()
    next_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
    next_graph.__enter__()

    def _deferred_apply(new_infer_state: InferStateInfo):
        # Uses the infer state supplied at call time, not the one seen during capture.
        apply_deepstack_features(
            input_embeddings=detached_embeddings,
            infer_state=new_infer_state,
            layer_num=layer_num,
        )

    infer_state.prefill_cuda_graph_add_cpu_runnning_func(func=_deferred_apply, after_graph=finished_graph)
0 commit comments