@@ -279,7 +279,7 @@ def generate_stream(self, params):
                 max_input_tile_temp = []
                 for image_str in message['image']:
                     pil_images.append(load_image_from_base64(image_str))
-                    prefix += f'Image-{global_image_cnt + 1}: <image>\n\n'
+                    prefix += f'Image-{global_image_cnt + 1}: <image>\n'
                     global_image_cnt += 1
                     max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
                 if len(max_input_tile_temp) > 0:
@@ -291,8 +291,8 @@ def generate_stream(self, params):
         question, history = history[-1][0], history[:-1]

         if global_image_cnt == 1:
-            question = question.replace('Image-1: <image>\n\n', '<image>\n')
-            history = [[item[0].replace('Image-1: <image>\n\n', '<image>\n'), item[1]] for item in history]
+            question = question.replace('Image-1: <image>\n', '<image>\n')
+            history = [[item[0].replace('Image-1: <image>\n', '<image>\n'), item[1]] for item in history]

         # Create a new list to store processed sublists
         flattened_list = []
@@ -308,7 +308,7 @@ def generate_stream(self, params):

         old_system_message = self.model.system_message
         self.model.system_message = system_message
-        image_tiles = []
+        image_tiles, num_patches_list = [], []
         transform = build_transform(input_size=self.image_size)
         if len(pil_images) > 0:
             for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
@@ -318,6 +318,7 @@ def generate_stream(self, params):
                         use_thumbnail=self.model.config.use_thumbnail)
                 else:
                     tiles = [pil_image]
+                num_patches_list.append(len(tiles))
                 image_tiles += tiles
             pixel_values = [transform(item) for item in image_tiles]
             pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
@@ -341,6 +342,7 @@ def generate_stream(self, params):
         thread = Thread(target=self.model.chat, kwargs=dict(
             tokenizer=self.tokenizer,
             pixel_values=pixel_values,
+            num_patches_list=num_patches_list,
             question=question,
             history=history,
             return_history=False,
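
The hunks above do two things: the `Image-N: <image>` prefix loses its extra trailing newline, and the worker now records how many tiles each source image contributes (`num_patches_list`) and forwards that list to `model.chat`, so the flat tile batch can be split back per `<image>` placeholder. Below is a minimal sketch of that bookkeeping, not the worker's real code: the names `collect_tiles`, `tiler`, and the fake images are illustrative stand-ins for PIL images and the repo's `dynamic_preprocess`.

```python
# Hypothetical sketch of the num_patches_list bookkeeping added in this diff.
def collect_tiles(images, max_tiles_per_image, dynamic, tiler):
    image_tiles, num_patches_list = [], []
    for max_num, img in zip(max_tiles_per_image, images):
        tiles = tiler(img, max_num) if dynamic else [img]
        num_patches_list.append(len(tiles))  # one count per source image
        image_tiles += tiles                 # flat list fed to the vision encoder
    return image_tiles, num_patches_list

# Toy usage: fake "images" and a fake tiler standing in for dynamic_preprocess.
fake_tiler = lambda img, max_num: [f'{img}_tile{i}' for i in range(max_num)]
tiles, counts = collect_tiles(['img_a', 'img_b'], [3, 2], True, fake_tiler)
print(counts)  # [3, 2] -> passed along as num_patches_list to model.chat
```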