@@ -180,3 +180,91 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
180180 "label" : label_str ,
181181 "confidence" : score ,
182182 }
183+
184+ @torch .inference_mode ()
185+ def predict_batch (
186+ self , images : List [Union [Image .Image , np .ndarray ]]
187+ ) -> List [List [dict ]]:
188+ """
189+ Batch prediction for multiple images - more efficient than calling predict() multiple times.
190+
191+ Parameters
192+ ----------
193+ images : List[Union[Image.Image, np.ndarray]]
194+ List of images to process in a single batch
195+
196+ Returns
197+ -------
198+ List[List[dict]]
199+ List of prediction lists, one per input image. Each prediction dict contains:
200+ "label", "confidence", "l", "t", "r", "b"
201+ """
202+ if not images :
203+ return []
204+
205+ # Convert all images to RGB PIL format
206+ pil_images = []
207+ for img in images :
208+ if isinstance (img , Image .Image ):
209+ pil_images .append (img .convert ("RGB" ))
210+ elif isinstance (img , np .ndarray ):
211+ pil_images .append (Image .fromarray (img ).convert ("RGB" ))
212+ else :
213+ raise TypeError ("Not supported input image format" )
214+
215+ # Get target sizes for all images
216+ target_sizes = torch .tensor ([img .size [::- 1 ] for img in pil_images ])
217+
218+ # Process all images in a single batch
219+ inputs = self ._image_processor (images = pil_images , return_tensors = "pt" ).to (
220+ self ._device
221+ )
222+ outputs = self ._model (** inputs )
223+
224+ # Post-process all results at once
225+ results_list : List [Dict [str , Tensor ]] = (
226+ self ._image_processor .post_process_object_detection (
227+ outputs ,
228+ target_sizes = target_sizes ,
229+ threshold = self ._threshold ,
230+ )
231+ )
232+
233+ # Convert results to standard format for each image
234+ all_predictions = []
235+
236+ for img , results in zip (pil_images , results_list ):
237+ w , h = img .size
238+ predictions = []
239+
240+ for score , label_id , box in zip (
241+ results ["scores" ], results ["labels" ], results ["boxes" ]
242+ ):
243+ score = float (score .item ())
244+ label_id = int (label_id .item ()) + self ._label_offset
245+ label_str = self ._classes_map [label_id ]
246+
247+ # Filter out blacklisted classes
248+ if label_str in self ._black_classes :
249+ continue
250+
251+ bbox_float = [float (b .item ()) for b in box ]
252+ l = min (w , max (0 , bbox_float [0 ]))
253+ t = min (h , max (0 , bbox_float [1 ]))
254+ r = min (w , max (0 , bbox_float [2 ]))
255+ b = min (h , max (0 , bbox_float [3 ]))
256+
257+ predictions .append (
258+ {
259+ "l" : l ,
260+ "t" : t ,
261+ "r" : r ,
262+ "b" : b ,
263+ "label" : label_str ,
264+ "confidence" : score ,
265+ }
266+ )
267+
268+ all_predictions .append (predictions )
269+
270+ return all_predictions
0 commit comments