Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- name: Setup uv python package manager
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4
Expand Down
12 changes: 11 additions & 1 deletion docs/cli.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# CLI Commands
## CLI commands

This page documents the SAHI CLI (the program exposed by `sahi/cli.py`).

### Top-level commands

- `predict` — run sliced/standard prediction pipeline and export results (images, COCO json, pickles)
- `predict-fiftyone` — run prediction and open results in FiftyOne
- `coco` — subgroup for COCO-format utilities (evaluate, analyse, convert, slice)
- `version` — print SAHI package version
- `env` — print environment and dependency versions

SAHI provides a comprehensive command-line interface for object detection tasks. This guide covers all available commands with detailed examples and options.

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ dependencies = [
"pybboxes==0.1.6",
"pillow>=8.2.0",
"pyyaml",
"fire",
"terminaltables",
"requests",
"click",
Expand Down
1 change: 0 additions & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
except importlib_metadata.PackageNotFoundError:
__version__ = "development"


from sahi.annotation import BoundingBox, Category, Mask
from sahi.auto_model import AutoDetectionModel
from sahi.models.base import DetectionModel
Expand Down
42 changes: 37 additions & 5 deletions sahi/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import fire
import click

from sahi import __version__ as sahi_version
from sahi.predict import predict, predict_fiftyone
Expand All @@ -7,7 +7,15 @@
from sahi.scripts.coco_error_analysis import analyse
from sahi.scripts.coco_evaluation import evaluate
from sahi.scripts.slice_coco import slicer
from sahi.utils.import_utils import print_environment_info
from sahi.utils.cli_helper import make_click_command
from sahi.utils.package_utils import print_environment_info


@click.group(help="SAHI command-line utilities: slicing-aided high-resolution inference and COCO tools")
def cli():
"""Top-level click group for SAHI CLI."""
pass


coco_app = {
"evaluate": evaluate,
Expand All @@ -18,19 +26,43 @@
"yolov5": coco2yolo,
}


# Create wrapper functions with proper help text for commands that need it
def version_command():
"""Show SAHI version."""
click.echo(sahi_version)


def env_command():
"""Show environment information."""
print_environment_info()


sahi_app = {
"predict": predict,
"predict-fiftyone": predict_fiftyone,
"coco": coco_app,
"version": sahi_version,
"env": print_environment_info,
"version": version_command,
"env": env_command,
}


def app() -> None:
"""Cli app."""
fire.Fire(sahi_app)

for command_name, command_func in sahi_app.items():
if isinstance(command_func, dict):
# add subcommands (create a named Group with help text)
sub_cli = click.Group(command_name, help=f"{command_name} related commands")
for sub_command_name, sub_command_func in command_func.items():
sub_cli.add_command(make_click_command(sub_command_name, sub_command_func))
cli.add_command(sub_cli)
else:
cli.add_command(make_click_command(command_name, command_func))

cli()


if __name__ == "__main__":
# build the application (register commands) and run
app()
2 changes: 1 addition & 1 deletion sahi/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sahi.annotation import Category
from sahi.logger import logger
from sahi.prediction import ObjectPrediction
from sahi.utils.import_utils import check_requirements
from sahi.utils.package_utils import check_requirements
from sahi.utils.torch_utils import empty_cuda_cache, select_device


Expand Down
4 changes: 2 additions & 2 deletions sahi/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import ensure_package_minimum_version
from sahi.utils.package_utils import check_package_minimum_version


class HuggingfaceDetectionModel(DetectionModel):
Expand All @@ -33,7 +33,7 @@ def __init__(
self._token = token
existing_packages = getattr(self, "required_packages", None) or []
self.required_packages = [*list(existing_packages), "torch", "transformers"]
ensure_package_minimum_version("transformers", "4.42.0")
check_package_minimum_version("transformers", "4.42.0", raise_error=True)
super().__init__(
model_path,
model,
Expand Down
2 changes: 1 addition & 1 deletion sahi/models/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_bbox_from_bool_mask, get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements
from sahi.utils.package_utils import check_requirements

check_requirements(["torch", "mmdet", "mmcv", "mmengine"])

Expand Down
2 changes: 1 addition & 1 deletion sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_coco_segmentation_from_bool_mask
from sahi.utils.import_utils import check_requirements
from sahi.utils.package_utils import check_requirements


class UltralyticsDetectionModel(DetectionModel):
Expand Down
4 changes: 2 additions & 2 deletions sahi/models/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.import_utils import check_package_minimum_version
from sahi.utils.package_utils import check_package_minimum_version


class Yolov5DetectionModel(DetectionModel):
Expand Down Expand Up @@ -77,7 +77,7 @@ def has_mask(self):

@property
def category_names(self):
if check_package_minimum_version("yolov5", "6.2.0"):
if check_package_minimum_version("yolov5", "6.2.0", raise_error=False):
return list(self.model.names.values())
else:
return self.model.names
Expand Down
3 changes: 0 additions & 3 deletions sahi/postprocess/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sahi.logger import logger
from sahi.postprocess.utils import ObjectPredictionList, has_match, merge_object_prediction_pair
from sahi.prediction import ObjectPrediction
from sahi.utils.import_utils import check_requirements


def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5):
Expand Down Expand Up @@ -458,8 +457,6 @@ def __init__(
self.class_agnostic = class_agnostic
self.match_metric = match_metric

check_requirements(["torch"])

def __call__(self, predictions: list[ObjectPrediction]):
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
visualize_object_predictions,
)
from sahi.utils.file import Path, increment_path, list_files, save_json, save_pickle
from sahi.utils.import_utils import check_requirements
from sahi.utils.package_utils import check_requirements

POSTPROCESS_NAME_TO_CLASS = {
"GREEDYNMM": GreedyNMMPostprocess,
Expand Down
8 changes: 2 additions & 6 deletions sahi/scripts/coco2fiftyone.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import time
from pathlib import Path

import fire

from sahi.utils.file import load_json


Expand All @@ -13,6 +11,8 @@ def main(
iou_thresh: float = 0.5,
):
"""
Launch fiftyone app with coco dataset and coco results.

Args:
image_dir (str): directory for coco images
dataset_json_path (str): file path for the coco dataset json file
Expand Down Expand Up @@ -75,7 +75,3 @@ def main(
print(f"SAHI has successfully launched a Fiftyone app at http://localhost:{fo.config.default_app_port}")
while 1:
time.sleep(3)


if __name__ == "__main__":
fire.Fire(main)
29 changes: 15 additions & 14 deletions sahi/scripts/coco2yolo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import fire

from sahi.utils.coco import Coco
from sahi.utils.file import Path, increment_path

Expand All @@ -15,17 +13,24 @@ def main(
seed: int = 1,
disable_symlink=False,
):
"""
"""Convert COCO dataset to YOLO format.

Args:
images_dir (str): directory for coco images
dataset_json_path (str): file path for the coco json file to be converted
train_split (float or int): set the training split ratio
project (str): save results to project/name
name (str): save results to project/name"
seed (int): fix the seed for reproducibility
disable_symlink (bool): required in google colab env
image_dir: Directory containing COCO images
dataset_json_path: Path to the COCO JSON annotation file
train_split: Training split ratio (0.0 to 1.0)
project: Project directory for output
name: Experiment name for output subdirectory
seed: Random seed for reproducibility
disable_symlink: Disable symbolic links (required in some environments)
"""

# Validate required parameters
if image_dir is None:
raise ValueError("image_dir is required")
if dataset_json_path is None:
raise ValueError("dataset_json_path is required")

# increment run
save_dir = Path(increment_path(Path(project) / name, exist_ok=False))
# load coco dict
Expand All @@ -42,7 +47,3 @@ def main(
)

print(f"COCO to YOLO conversion results are successfully exported to {save_dir}")


if __name__ == "__main__":
fire.Fire(main)
7 changes: 2 additions & 5 deletions sahi/scripts/coco_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from multiprocessing import Pool
from pathlib import Path

import fire
import numpy as np

from sahi.logger import logger
Expand Down Expand Up @@ -452,6 +451,8 @@ def analyse(
return_dict: bool = False,
) -> dict | None:
"""
Perform COCO error analysis and export result plots to out_dir.

Args:
dataset_json_path (str): file path for the coco dataset json file
result_json_paths (str): file path for the coco result json file
Expand Down Expand Up @@ -480,7 +481,3 @@ def analyse(
)
if return_dict:
return result


if __name__ == "__main__":
fire.Fire(analyse)
9 changes: 3 additions & 6 deletions sahi/scripts/coco_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pathlib import Path
from typing import Literal

import fire
import numpy as np
from terminaltables import AsciiTable

Expand Down Expand Up @@ -361,7 +360,8 @@ def evaluate(
areas: list[int] = [1024, 9216, 10000000000],
return_dict: bool = False,
):
"""
"""COCO evaluation entrypoint.

Args:
dataset_json_path (str): file path for the coco dataset json file
result_json_path (str): file path for the coco result json file
Expand All @@ -373,6 +373,7 @@ def evaluate(
areas (List[int]): area regions for coco evaluation calculations
return_dict (bool): If True, returns a dict with 'eval_results' 'export_path' fields.
"""

try:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
Expand All @@ -396,7 +397,3 @@ def evaluate(
)
if return_dict:
return result


if __name__ == "__main__":
fire.Fire(evaluate)
11 changes: 0 additions & 11 deletions sahi/scripts/predict.py

This file was deleted.

11 changes: 0 additions & 11 deletions sahi/scripts/predict_fiftyone.py

This file was deleted.

8 changes: 2 additions & 6 deletions sahi/scripts/slice_coco.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os

import fire

from sahi.slicing import slice_coco
from sahi.utils.file import Path, save_json

Expand All @@ -16,6 +14,8 @@ def slicer(
min_area_ratio: float = 0.1,
):
"""
Slice COCO dataset images and annotations.

Args:
image_dir (str): directory for coco images
dataset_json_path (str): file path for the coco dataset json file
Expand Down Expand Up @@ -61,7 +61,3 @@ def slicer(
output_coco_annotation_file_path = os.path.join(output_dir, sliced_coco_name + ".json")
save_json(coco_dict, output_coco_annotation_file_path)
print(f"Sliced dataset for 'slice_size: {slice_size}' is exported to {output_dir}")


if __name__ == "__main__":
fire.Fire(slice)
Loading