 
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
 
 from lightllm.server.embed_cache.utils import (
     bytes2tensor,
@@ -20,13 +21,14 @@ def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
         return
 
-    def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
-
+    def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
         img_weight = []
-        img_start_token_ids = []
-        img_token_lens = []
         img_start_loc = 0
-        img_start_locs = []
+
+        infer_state.input_ids = input_ids
+        infer_state.img_start_token_ids = []
+        infer_state.img_token_lens = []
+        infer_state.img_start_locs = []
 
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
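
For orientation, a minimal sketch of the infer-state fields this hunk starts populating. The field list is inferred from the assignments in this diff; the real Qwen3VLInferStateInfo in lightllm/models/qwen3_vl/infer_struct.py may define more (or different) members.

# Hypothetical sketch, inferred from the assignments above; not the actual class.
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

class Qwen3VLInferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self.input_ids = None            # raw prompt ids, set in context_forward
        self.img_start_token_ids = []    # first reserved token id of each image
        self.img_token_lens = []         # token count of each image
        self.img_start_locs = []         # row offset of each image in img_weight
        self.deepstack_features = []     # per-image list of per-layer embeds
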
@@ -37,12 +39,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                if img["token_id"] in infer_state.img_start_token_ids or img["_prefill_"] is False:
                     continue
-                pos = (input_ids == img["token_id"]).nonzero(as_tuple=True)
-                if pos[0].numel() == 0:
-                    continue
-                # pull the img_embeds by uid from shm
+
                 all_img_embed_df = bytes2tensor(read_shm(get_shm_name_embed(img["uuid"])))
                 per_image_deepstack = []
 
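
The embeddings arrive as raw bytes in shared memory, keyed by the image uuid. As a rough stand-in for the helpers imported from lightllm.server.embed_cache.utils (the buffer layout and dtype here are assumptions, not the library's actual implementation):

# Simplified stand-ins for read_shm / bytes2tensor; layout and dtype assumed.
from multiprocessing import shared_memory
import torch

def read_shm_sketch(name: str) -> bytes:
    shm = shared_memory.SharedMemory(name=name)
    try:
        return bytes(shm.buf)  # copy the buffer out before closing
    finally:
        shm.close()

def bytes2tensor_sketch(data: bytes, hidden: int) -> torch.Tensor:
    # assumes the producer wrote a flat (n_tokens * hidden) float16 buffer
    flat = torch.frombuffer(bytearray(data), dtype=torch.float16)
    return flat.view(-1, hidden)
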
@@ -55,12 +54,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
                     per_image_deepstack.append(all_img_embed_df[start:end])
 
                 infer_state.deepstack_features.append(per_image_deepstack)
-                img_insert_locs = int(pos[0][0])
-                infer_state.img_first_token_locs.append(img_insert_locs)
-                infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"])
-                img_start_token_ids.append(img["token_id"])
-                img_token_lens.append(img["token_num"])
-                img_start_locs.append(img_start_loc)
+                infer_state.img_start_token_ids.append(img["token_id"])
+                infer_state.img_token_lens.append(img["token_num"])
+                infer_state.img_start_locs.append(img_start_loc)
                 img_start_loc += img["token_num"]
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
 
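
The slicing elided between these hunks partitions the flat shm tensor into the main image embedding plus one block per deepstack layer. A minimal sketch of that partitioning, assuming the buffer simply stacks token_num-row blocks (main embeds first, deepstack layers after):

import torch

def split_embeds(all_img_embed_df: torch.Tensor, token_num: int):
    # Assumed layout: [main embeds | deepstack layer 0 | layer 1 | ...],
    # each block token_num rows; inferred from the slicing in this hunk.
    num_blocks = all_img_embed_df.shape[0] // token_num
    img_embed = all_img_embed_df[:token_num]
    per_image_deepstack = [
        all_img_embed_df[i * token_num : (i + 1) * token_num]
        for i in range(1, num_blocks)
    ]
    return img_embed, per_image_deepstack
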
@@ -74,9 +70,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
             )
             # each tp will fill the img embeds, should divide by world_size
             img_weight = img_weight / self.tp_world_size_
-            img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
-            img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
-            img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+            img_start_token_ids = torch.Tensor(infer_state.img_start_token_ids).to(device=device, dtype=torch.long)
+            img_token_lens = torch.Tensor(infer_state.img_token_lens).to(device=device, dtype=torch.long)
+            img_start_locs = torch.Tensor(infer_state.img_start_locs).to(device=device, dtype=torch.long)
 
         multimodal_emb(
             out,
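
multimodal_emb then fills out from both sources in one pass: text token rows come from the wte table, image token rows from img_weight. A pure-PyTorch sketch of that fill, under the assumption that each image's tokens carry a contiguous range of reserved ids starting at its token_id; lightllm's real fused kernel may differ in its exact contract.

import torch

def multimodal_emb_reference(
    out: torch.Tensor,                  # (seq_len, hidden), zero-initialized
    input_ids: torch.Tensor,            # (seq_len,)
    wte: torch.Tensor,                  # (vocab, hidden) text embedding table
    img_weight: torch.Tensor,           # (total_img_tokens, hidden)
    img_start_token_ids: torch.Tensor,  # (num_imgs,) first id of each image
    img_token_lens: torch.Tensor,       # (num_imgs,) token count per image
    img_start_locs: torch.Tensor,       # (num_imgs,) row offset in img_weight
) -> None:
    # Text tokens: plain table lookup for ids inside the text vocab.
    text_mask = input_ids < wte.shape[0]
    out[text_mask] = wte[input_ids[text_mask]]
    # Image tokens: assumed reserved ids, contiguous per image, so the
    # position inside the image is (id - start_id).
    for sid, tlen, sloc in zip(
        img_start_token_ids.tolist(), img_token_lens.tolist(), img_start_locs.tolist()
    ):
        m = (input_ids >= sid) & (input_ids < sid + tlen)
        out[m] = img_weight[sloc + (input_ids[m] - sid)]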