Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def perform_inference(self, image: np.ndarray):

if self.image_size is not None:
kwargs = {"imgsz": self.image_size, **kwargs}
if type(image) is list:

prediction_result = self.model(image[:, :, ::-1], **kwargs) # YOLOv8 expects numpy arrays to have BGR

prediction_result = self.model(image, **kwargs) # YOLOv8 expects numpy arrays to have BGR
else :
prediction_result = self.model(image[:, :, ::-1], **kwargs)
if self.has_mask:
if not prediction_result[0].masks:
prediction_result[0].masks = Masks(
Expand Down Expand Up @@ -109,7 +111,10 @@ def perform_inference(self, image: np.ndarray):
prediction_result = [result.boxes.data for result in prediction_result]

self._original_predictions = prediction_result
self._original_shape = image.shape
if type(image) == list:
self._original_shape = image[0].shape
else:
self._original_shape = image.shape

@property
def category_names(self):
Expand Down
50 changes: 32 additions & 18 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ def get_prediction(
durations_in_seconds = dict()

# read image as pil
image_as_pil = read_image_as_pil(image)
# image_as_pil = read_image_as_pil(image)
# get prediction
time_start = time.time()
detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
# detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
detection_model.perform_inference(image)
time_end = time.time() - time_start
durations_in_seconds["prediction"] = time_end

Expand All @@ -104,12 +105,10 @@ def get_prediction(
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list

object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list_per_image
# postprocess matching predictions
if postprocess is not None:
object_prediction_list = postprocess(object_prediction_list)

time_end = time.time() - time_start
durations_in_seconds["postprocess"] = time_end

Expand Down Expand Up @@ -142,6 +141,7 @@ def get_sliced_prediction(
auto_slice_resolution: bool = True,
slice_export_prefix: Optional[str] = None,
slice_dir: Optional[str] = None,
num_batch: int = 1
) -> PredictionResult:
"""
    Function for slice image + get prediction for each slice + combine predictions in full image.
Expand Down Expand Up @@ -201,8 +201,8 @@ def get_sliced_prediction(
# for profiling
durations_in_seconds = dict()

# currently only 1 batch supported
num_batch = 1
# # currently only 1 batch supported
# num_batch = 1
# create slices from full image
time_start = time.time()
slice_image_result = slice_image(
Expand Down Expand Up @@ -236,7 +236,8 @@ def get_sliced_prediction(
)

# create prediction input
num_group = int(num_slices / num_batch)
# num_group = int(num_slices / num_batch)
num_group = math.ceil(num_slices / num_batch)
if verbose == 1 or verbose == 2:
tqdm.write(f"Performing prediction on {num_slices} slices.")
object_prediction_list = []
Expand All @@ -246,22 +247,31 @@ def get_sliced_prediction(
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
if (group_ind * num_batch + image_ind) >= num_slices:
break
# image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
img_slice = slice_image_result.images[group_ind * num_batch + image_ind]
img_slice = img_slice[:,:,::-1]
image_list.append(img_slice)
shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
# perform batch prediction
num_full = len(image_list)
prediction_result = get_prediction(
image=image_list[0],
image=image_list,
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
shift_amount=shift_amount_list,
full_shape=[[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
]] * num_full,
)

# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
if object_prediction: # if not empty
object_prediction_list.append(object_prediction.get_shifted_object_prediction())
for object_prediction_per in prediction_result.object_prediction_list:

if len(object_prediction_per) != 0: # if not empty
for object_prediction in object_prediction_per:
object_prediction_list.append(object_prediction.get_shifted_object_prediction())

# merge matching predictions during sliced prediction
if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
Expand All @@ -270,7 +280,7 @@ def get_sliced_prediction(
# perform standard prediction
if num_slices > 1 and perform_standard_pred:
prediction_result = get_prediction(
image=image,
image=[np.array(image)],
detection_model=detection_model,
shift_amount=[0, 0],
full_shape=[
Expand All @@ -279,7 +289,9 @@ def get_sliced_prediction(
],
postprocess=None,
)
object_prediction_list.extend(prediction_result.object_prediction_list)
if len(prediction_result.object_prediction_list) != 0:
for _predicion_result in prediction_result.object_prediction_list:
object_prediction_list.extend(_predicion_result)

Comment on lines +328 to 330
Copy link

Copilot AI Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Typo detected: '_predicion_result' should be renamed to '_prediction_result' for clarity.

Suggested change
for _predicion_result in prediction_result.object_prediction_list:
object_prediction_list.extend(_predicion_result)
for _prediction_result in prediction_result.object_prediction_list:
object_prediction_list.extend(_prediction_result)

Copilot uses AI. Check for mistakes.
# merge matching predictions
if len(object_prediction_list) > 1:
Expand Down Expand Up @@ -380,6 +392,7 @@ def predict(
verbose: int = 1,
return_dict: bool = False,
force_postprocess_type: bool = False,
num_batch: int = 1,
**kwargs,
):
"""
Expand Down Expand Up @@ -574,6 +587,7 @@ def predict(
postprocess_match_threshold=postprocess_match_threshold,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=1 if verbose else 0,
num_batch = num_batch,
)
object_prediction_list = prediction_result.object_prediction_list
if prediction_result.durations_in_seconds:
Expand Down
9 changes: 7 additions & 2 deletions sahi/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ def __init__(
image: Union[Image.Image, str, np.ndarray],
durations_in_seconds: Dict[str, Any] = dict(),
):
self.image: Image.Image = read_image_as_pil(image)
self.image_width, self.image_height = self.image.size

if type(image) is list:
self.image = image
self.image_width, self.image_height = self.image[0].shape[:2]
else :
self.image: Image.Image = read_image_as_pil(image)
self.image_width, self.image_height = self.image.size
self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
self.durations_in_seconds = durations_in_seconds

Expand Down
Loading