@@ -26,7 +26,7 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
2626 img_start_loc = 0
2727
2828 infer_state .input_ids = input_ids
29- img_start_token_ids = []
29+ infer_state . img_start_token_ids = []
3030 img_token_lens = []
3131 img_start_locs = []
3232
@@ -39,10 +39,10 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
3939 for batch_id , p in enumerate (infer_state .multimodal_params ):
4040 for img in p ["images" ] + p ["audios" ]:
4141 # skip the same image
42- if img ["token_id" ] in img_start_token_ids or img ["_prefill_" ] is False :
42+ if img ["token_id" ] in infer_state . img_start_token_ids or img ["_prefill_" ] is False :
4343 continue
44- infer_state .image_num_need_deepstack += 1
4544
45+ infer_state .image_num_need_deepstack += 1
4646 # shape of all_img_embed_df is
4747 # image_embed(token_num, hidden_dim) + deepstack(token_num*layer_num, hidden_dim)
4848 all_img_embed_df = bytes2tensor (read_shm (get_shm_name_embed (img ["uuid" ])))
@@ -58,7 +58,7 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
5858 per_image_deepstack .append (all_img_embed_df [start :end ])
5959
6060 infer_state .deepstack_features .append (per_image_deepstack )
61- img_start_token_ids .append (img ["token_id" ])
61+ infer_state . img_start_token_ids .append (img ["token_id" ])
6262 img_token_lens .append (img ["token_num" ])
6363 img_start_locs .append (img_start_loc )
6464 img_start_loc += img ["token_num" ]
@@ -74,7 +74,7 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
7474 )
7575 # each tp will fill the img embeds, should divide by world_size
7676 img_weight = img_weight / self .tp_world_size_
77- infer_state . img_start_token_ids = torch .Tensor (img_start_token_ids ).to (device = device , dtype = torch .long )
77+ img_start_token_ids = torch .Tensor (infer_state . img_start_token_ids ).to (device = device , dtype = torch .long )
7878 infer_state .img_token_lens = torch .Tensor (img_token_lens ).to (device = device , dtype = torch .long )
7979 infer_state .img_start_locs = torch .Tensor (img_start_locs ).to (device = device , dtype = torch .long )
8080
@@ -84,7 +84,7 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
8484 layer_weight .wte_weight_ ,
8585 img_weight ,
8686 infer_state .img_token_lens ,
87- infer_state . img_start_token_ids ,
87+ img_start_token_ids ,
8888 infer_state .img_start_locs ,
8989 self .vob_start_id_ ,
9090 self .vob_end_id_ ,
0 commit comments