77from lightllm .models .vit .layer_weights .pre_and_post_layer_weight import ViTPreAndPostLayerWeight
88from lightllm .models .vit .layer_weights .transformer_layer_weight import ViTTransformerLayerWeight
99from lightllm .models .vit .layer_weights .hf_load_utils import load_hf_weights
10+ from lightllm .server .multimodal_params import MultimodalParams , ImageItem
1011from lightllm .common .build_utils import repair_config
1112from lightllm .utils .log_utils import init_logger
1213from lightllm .models .vit import get_load_image_func
@@ -135,21 +136,20 @@ def forward(self, pixel_values):
135136 return input_embs
136137
137138 @torch .no_grad ()
138- def encode (self , image_uuids : List , max_num_list : List ):
139+ def encode (self , images : List [ ImageItem ] ):
139140 img_tensors = []
140141 valid_ids = []
141142 valid_id = 0
142143 uuids = []
143- for i , url in enumerate (image_uuids ):
144- if isinstance (url , int ):
145- uuids .append (url )
146- image_data = read_shm (get_shm_name_data (url ))
144+ for i , img in enumerate (images ):
145+ if isinstance (img , ImageItem ):
146+ uuids .append (img . uuid )
147+ image_data = read_shm (get_shm_name_data (img . uuid ))
147148 image_data = Image .open (BytesIO (image_data ))
148- max_num = max_num_list [i ]
149- t = self .load_image_func (image_data , max_num = max_num )
149+ t = self .load_image_func (image_data , max_num = img .extra_params ["image_patch_max_num" ])
150150 img_tensors .append (t )
151151 else :
152- raise Exception ("Unsupport input types: {} for {}" .format (type (url ), url ))
152+ raise Exception ("Unsupport input types: {} for {}" .format (type (img ), img ))
153153
154154 cur_num = img_tensors [- 1 ].shape [0 ]
155155 valid_ids .append ([valid_id , valid_id + cur_num ])
@@ -160,7 +160,6 @@ def encode(self, image_uuids: List, max_num_list: List):
160160
161161 imgs = torch .cat (img_tensors , dim = 0 )
162162 pixel_values = imgs .cuda ().to (dtype = self .data_type )
163- print (pixel_values .shape , pixel_values .dtype )
164163 all_img_embeds = self .forward (pixel_values )
165164 return all_img_embeds , uuids , valid_ids
166165
0 commit comments