Skip to content

Commit e9e5954

Browse files
committed
new batch infer
1 parent b7d0ffd commit e9e5954

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
model=None,
7171
memory_optimize=True,
7272
inference_extra_args=None,
73+
batch_inference=False,
7374
) -> None:
7475
if model is None:
7576
self.model = YOLO(model_path) # Load the model from the specified path
@@ -91,6 +92,7 @@ def __init__(
9192
self.memory_optimize = memory_optimize # memory opimization option for segmentation
9293
self.class_names_dict = self.model.names # dict with human-readable class names
9394
self.inference_extra_args = inference_extra_args # dict with extra ultralytics inference parameters
95+
self.batch_inference = batch_inference
9496

9597
self.crops = self.get_crops_xy(
9698
self.image,
@@ -100,7 +102,10 @@ def __init__(
100102
overlap_y=self.overlap_y,
101103
show=self.show_crops,
102104
)
103-
self._detect_objects()
105+
if self.batch_inference:
106+
self._detect_objects_batch()
107+
else:
108+
self._detect_objects()
104109

105110
def get_crops_xy(
106111
self,
@@ -141,6 +146,7 @@ def get_crops_xy(
141146
x_new = round((x_steps-1) * (shape_x * cross_koef_x) + shape_x)
142147
image_innitial = image_full.copy()
143148
image_full = cv2.resize(image_full, (x_new, y_new))
149+
batch_of_crops = []
144150

145151
if show:
146152
plt.figure(figsize=[x_steps*0.9, y_steps*0.9])
@@ -176,12 +182,17 @@ def get_crops_xy(
176182
x_start=x_start,
177183
y_start=y_start,
178184
))
185+
if self.batch_inference:
186+
batch_of_crops.append(im_temp)
179187

180188
if show:
181189
plt.show()
182190
print('Number of generated images:', count)
183191

184-
return data_all_crops
192+
if self.batch_inference:
193+
return data_all_crops, batch_of_crops
194+
else:
195+
return data_all_crops
185196

186197
def _detect_objects(self):
187198
"""
@@ -207,3 +218,76 @@ def _detect_objects(self):
207218
crop.calculate_real_values()
208219
if self.resize_initial_size:
209220
crop.resize_results()
221+
222+
def _detect_objects_batch(self):
223+
"""
224+
Method to detect objects in batch of crop.
225+
226+
This method performs batch inference using the YOLO model,
227+
calculates real values, and optionally resizes the results.
228+
229+
Returns:
230+
None
231+
"""
232+
crops, batch = self.crops
233+
self.crops = crops
234+
self._calculate_batch_inference(
235+
batch,
236+
self.crops,
237+
self.model,
238+
imgsz=self.imgsz,
239+
conf=self.conf,
240+
iou=self.iou,
241+
segment=self.segment,
242+
classes_list=self.classes_list,
243+
memory_optimize=self.memory_optimize,
244+
extra_args=self.inference_extra_args
245+
)
246+
for crop in self.crops:
247+
crop.calculate_real_values()
248+
if self.resize_initial_size:
249+
crop.resize_results()
250+
251+
def _calculate_batch_inference(
252+
self,
253+
batch,
254+
crops,
255+
model,
256+
imgsz=640,
257+
conf=0.35,
258+
iou=0.7,
259+
segment=False,
260+
classes_list=None,
261+
memory_optimize=False,
262+
extra_args=None,
263+
):
264+
# Perform inference
265+
extra_args = {} if extra_args is None else extra_args
266+
predictions = model.predict(
267+
batch,
268+
imgsz=imgsz,
269+
conf=conf,
270+
iou=iou,
271+
classes=classes_list,
272+
verbose=False,
273+
**extra_args
274+
)
275+
276+
for pred, crop in zip(predictions, crops):
277+
278+
# Get the bounding boxes and convert them to a list of lists
279+
crop.detected_xyxy = pred.boxes.xyxy.cpu().int().tolist()
280+
281+
# Get the classes and convert them to a list
282+
crop.detected_cls = pred.boxes.cls.cpu().int().tolist()
283+
284+
# Get the mask confidence scores
285+
crop.detected_conf = pred.boxes.conf.cpu().numpy()
286+
287+
if segment and len(crop.detected_cls) != 0:
288+
if memory_optimize:
289+
# Get the polygons
290+
crop.polygons = [mask.astype(np.uint16) for mask in pred.masks.xy]
291+
else:
292+
# Get the masks
293+
crop.detected_masks = pred.masks.data.cpu().numpy()

0 commit comments

Comments
 (0)