@@ -70,6 +70,7 @@ def __init__(
7070 model = None ,
7171 memory_optimize = True ,
7272 inference_extra_args = None ,
73+ batch_inference = False ,
7374 ) -> None :
7475 if model is None :
7576 self .model = YOLO (model_path ) # Load the model from the specified path
@@ -91,6 +92,7 @@ def __init__(
9192 self .memory_optimize = memory_optimize # memory opimization option for segmentation
9293 self .class_names_dict = self .model .names # dict with human-readable class names
9394 self .inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters
95+ self .batch_inference = batch_inference
9496
9597 self .crops = self .get_crops_xy (
9698 self .image ,
@@ -100,7 +102,10 @@ def __init__(
100102 overlap_y = self .overlap_y ,
101103 show = self .show_crops ,
102104 )
103- self ._detect_objects ()
105+ if self .batch_inference :
106+ self ._detect_objects_batch ()
107+ else :
108+ self ._detect_objects ()
104109
105110 def get_crops_xy (
106111 self ,
@@ -141,6 +146,7 @@ def get_crops_xy(
141146 x_new = round ((x_steps - 1 ) * (shape_x * cross_koef_x ) + shape_x )
142147 image_innitial = image_full .copy ()
143148 image_full = cv2 .resize (image_full , (x_new , y_new ))
149+ batch_of_crops = []
144150
145151 if show :
146152 plt .figure (figsize = [x_steps * 0.9 , y_steps * 0.9 ])
@@ -176,12 +182,17 @@ def get_crops_xy(
176182 x_start = x_start ,
177183 y_start = y_start ,
178184 ))
185+ if self .batch_inference :
186+ batch_of_crops .append (im_temp )
179187
180188 if show :
181189 plt .show ()
182190 print ('Number of generated images:' , count )
183191
184- return data_all_crops
192+ if self .batch_inference :
193+ return data_all_crops , batch_of_crops
194+ else :
195+ return data_all_crops
185196
186197 def _detect_objects (self ):
187198 """
@@ -207,3 +218,76 @@ def _detect_objects(self):
207218 crop .calculate_real_values ()
208219 if self .resize_initial_size :
209220 crop .resize_results ()
221+
def _detect_objects_batch(self):
    """
    Run detection on all crops with one batched model call.

    On entry ``self.crops`` must hold the pair ``(crop objects, crop images)``;
    it is replaced by the crop objects alone. The images are handed to
    ``_calculate_batch_inference`` together with the inference settings stored
    on the instance. Afterwards every crop computes its real values and, when
    ``self.resize_initial_size`` is truthy, resizes its results.

    Returns:
        None
    """
    # Split the pair; only the crop objects stay on the instance.
    self.crops, crop_images = self.crops

    # Inference settings forwarded by keyword to the batch helper.
    inference_options = {
        "imgsz": self.imgsz,
        "conf": self.conf,
        "iou": self.iou,
        "segment": self.segment,
        "classes_list": self.classes_list,
        "memory_optimize": self.memory_optimize,
        "extra_args": self.inference_extra_args,
    }
    self._calculate_batch_inference(
        crop_images,
        self.crops,
        self.model,
        **inference_options,
    )

    # Post-process each crop once its detections have been stored on it.
    for element in self.crops:
        element.calculate_real_values()
        if not self.resize_initial_size:
            continue
        element.resize_results()
def _calculate_batch_inference(
    self,
    batch,
    crops,
    model,
    imgsz=640,
    conf=0.35,
    iou=0.7,
    segment=False,
    classes_list=None,
    memory_optimize=False,
    extra_args=None,
):
    """
    Run one batched prediction over all crop images and store the results
    on the matching crop objects.

    Args:
        batch: Sequence of crop images sent to ``model.predict`` in a single
            call. An empty batch is a no-op: the model is not invoked.
        crops: Crop objects paired positionally with the predictions. Each
            one receives ``detected_xyxy``, ``detected_cls`` and
            ``detected_conf``; when ``segment`` is truthy and the crop has at
            least one detection it also receives ``polygons`` (if
            ``memory_optimize``) or ``detected_masks`` (otherwise).
        model: Object exposing ``predict``; presumably the Ultralytics YOLO
            model held by the instance.
        imgsz: Forwarded to ``model.predict`` as ``imgsz``.
        conf: Forwarded to ``model.predict`` as ``conf``.
        iou: Forwarded to ``model.predict`` as ``iou``.
        segment: Whether mask data should be read from the predictions.
        classes_list: Forwarded to ``model.predict`` as ``classes``.
        memory_optimize: With ``segment``, store polygons cast to ``uint16``
            instead of the raw mask array.
        extra_args: Optional dict of additional keyword arguments for
            ``model.predict``; ``None`` means no extra arguments.

    Returns:
        None
    """
    # Nothing to infer on: skip the model call instead of handing the
    # predictor an empty source. ``len`` is used rather than truthiness so
    # an array-like batch is handled as well as a list.
    if len(batch) == 0:
        return

    extra_args = {} if extra_args is None else extra_args

    # Perform inference on the whole batch at once
    predictions = model.predict(
        batch,
        imgsz=imgsz,
        conf=conf,
        iou=iou,
        classes=classes_list,
        verbose=False,
        **extra_args
    )

    for pred, crop in zip(predictions, crops):
        boxes = pred.boxes

        # Get the bounding boxes and convert them to a list of lists
        crop.detected_xyxy = boxes.xyxy.cpu().int().tolist()

        # Get the classes and convert them to a list
        crop.detected_cls = boxes.cls.cpu().int().tolist()

        # Get the box confidence scores
        crop.detected_conf = boxes.conf.cpu().numpy()

        # Mask data is only read when something was detected in this crop
        if segment and len(crop.detected_cls) != 0:
            if memory_optimize:
                # Get the polygons
                crop.polygons = [mask.astype(np.uint16) for mask in pred.masks.xy]
            else:
                # Get the masks
                crop.detected_masks = pred.masks.data.cpu().numpy()
0 commit comments