diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c86024183..41ad2cc47 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -93,7 +93,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] steps: - name: Setup uv python package manager uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/docs/cli.md b/docs/cli.md index 54ade3f53..2da1a293b 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -1,4 +1,14 @@ -# CLI Commands +## CLI commands + +This page documents the SAHI CLI (the program exposed by `sahi/cli.py`). + +### Top-level commands + +- `predict` — run sliced/standard prediction pipeline and export results (images, COCO json, pickles) +- `predict-fiftyone` — run prediction and open results in FiftyOne +- `coco` — subgroup for COCO-format utilities (evaluate, analyse, convert, slice) +- `version` — print SAHI package version +- `env` — print environment and dependency versions SAHI provides a comprehensive command-line interface for object detection tasks. This guide covers all available commands with detailed examples and options. diff --git a/pyproject.toml b/pyproject.toml index 2331e2a68..36932760f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ dependencies = [ "pybboxes==0.1.6", "pillow>=8.2.0", "pyyaml", - "fire", "terminaltables", "requests", "click", diff --git a/sahi/__init__.py b/sahi/__init__.py index 75e16e15c..72ed39a0a 100644 --- a/sahi/__init__.py +++ b/sahi/__init__.py @@ -6,7 +6,6 @@ except importlib_metadata.PackageNotFoundError: __version__ = "development" - from sahi.annotation import BoundingBox, Category, Mask from sahi.auto_model import AutoDetectionModel from sahi.models.base import DetectionModel diff --git a/sahi/cli.py b/sahi/cli.py index e086f4e38..e4864efe4 100644 --- a/sahi/cli.py +++ b/sahi/cli.py @@ -1,4 +1,4 @@ -import fire +import click from sahi import __version__ as sahi_version from sahi.predict import predict, predict_fiftyone @@ -7,7 +7,15 @@ from sahi.scripts.coco_error_analysis import analyse from sahi.scripts.coco_evaluation import evaluate from sahi.scripts.slice_coco import slicer -from sahi.utils.import_utils import print_environment_info +from sahi.utils.cli_helper import make_click_command +from sahi.utils.package_utils import print_environment_info + + +@click.group(help="SAHI command-line utilities: slicing-aided high-resolution inference and COCO tools") +def cli(): + """Top-level click group for SAHI CLI.""" + pass + coco_app = { "evaluate": evaluate, @@ -18,19 +26,43 @@ "yolov5": coco2yolo, } + +# Create wrapper functions with proper help text for commands that need it +def version_command(): + """Show SAHI version.""" + click.echo(sahi_version) + + +def env_command(): + """Show environment information.""" + print_environment_info() + + sahi_app = { "predict": predict, "predict-fiftyone": predict_fiftyone, "coco": coco_app, - "version": sahi_version, - "env": print_environment_info, + "version": version_command, + "env": env_command, } def app() -> None: """Cli app.""" - fire.Fire(sahi_app) + + for command_name, command_func in sahi_app.items(): + if isinstance(command_func, dict): + # add subcommands (create a named Group with help text) + sub_cli = click.Group(command_name, help=f"{command_name} related commands") + for sub_command_name, sub_command_func in command_func.items(): + sub_cli.add_command(make_click_command(sub_command_name, sub_command_func)) + cli.add_command(sub_cli) + else: + cli.add_command(make_click_command(command_name, command_func)) + + cli() if __name__ == "__main__": + # build the application (register commands) and run app() diff --git a/sahi/models/base.py b/sahi/models/base.py index b0e036c2d..26e2808b7 100644 --- a/sahi/models/base.py +++ b/sahi/models/base.py @@ -7,7 +7,7 @@ from sahi.annotation import Category from sahi.logger import logger from sahi.prediction import ObjectPrediction -from sahi.utils.import_utils import check_requirements +from sahi.utils.package_utils import check_requirements from sahi.utils.torch_utils import empty_cuda_cache, select_device diff --git a/sahi/models/huggingface.py b/sahi/models/huggingface.py index c78fda6e6..d0ff4f93d 100644 --- a/sahi/models/huggingface.py +++ b/sahi/models/huggingface.py @@ -9,7 +9,7 @@ from sahi.models.base import DetectionModel from sahi.prediction import ObjectPrediction from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list -from sahi.utils.import_utils import ensure_package_minimum_version +from sahi.utils.package_utils import check_package_minimum_version class HuggingfaceDetectionModel(DetectionModel): @@ -33,7 +33,7 @@ def __init__( self._token = token existing_packages = getattr(self, "required_packages", None) or [] self.required_packages = [*list(existing_packages), "torch", "transformers"] - ensure_package_minimum_version("transformers", "4.42.0") + check_package_minimum_version("transformers", "4.42.0", raise_error=True) super().__init__( model_path, model, diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index 0304347d9..8ccd0a621 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -9,7 +9,7 @@ from sahi.prediction import ObjectPrediction from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list from sahi.utils.cv import get_bbox_from_bool_mask, get_coco_segmentation_from_bool_mask -from sahi.utils.import_utils import check_requirements +from sahi.utils.package_utils import check_requirements check_requirements(["torch", "mmdet", "mmcv", "mmengine"]) diff --git a/sahi/models/ultralytics.py b/sahi/models/ultralytics.py index 1274e557a..dd362803c 100644 --- a/sahi/models/ultralytics.py +++ b/sahi/models/ultralytics.py @@ -10,7 +10,7 @@ from sahi.prediction import ObjectPrediction from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list from sahi.utils.cv import get_coco_segmentation_from_bool_mask -from sahi.utils.import_utils import check_requirements +from sahi.utils.package_utils import check_requirements class UltralyticsDetectionModel(DetectionModel): diff --git a/sahi/models/yolov5.py b/sahi/models/yolov5.py index a46e21345..74b1f6fba 100644 --- a/sahi/models/yolov5.py +++ b/sahi/models/yolov5.py @@ -8,7 +8,7 @@ from sahi.models.base import DetectionModel from sahi.prediction import ObjectPrediction from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list -from sahi.utils.import_utils import check_package_minimum_version +from sahi.utils.package_utils import check_package_minimum_version class Yolov5DetectionModel(DetectionModel): @@ -77,7 +77,7 @@ def has_mask(self): @property def category_names(self): - if check_package_minimum_version("yolov5", "6.2.0"): + if check_package_minimum_version("yolov5", "6.2.0", raise_error=False): return list(self.model.names.values()) else: return self.model.names diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py index 37009abb2..a087493c1 100644 --- a/sahi/postprocess/combine.py +++ b/sahi/postprocess/combine.py @@ -7,7 +7,6 @@ from sahi.logger import logger from sahi.postprocess.utils import ObjectPredictionList, has_match, merge_object_prediction_pair from sahi.prediction import ObjectPrediction -from sahi.utils.import_utils import check_requirements def batched_nms(predictions: torch.tensor, match_metric: str = "IOU", match_threshold: float = 0.5): @@ -458,8 +457,6 @@ def __init__( self.class_agnostic = class_agnostic self.match_metric = match_metric - check_requirements(["torch"]) - def __call__(self, predictions: list[ObjectPrediction]): raise NotImplementedError() diff --git a/sahi/predict.py b/sahi/predict.py index 6a16c2a51..9d676e0d6 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -32,7 +32,7 @@ visualize_object_predictions, ) from sahi.utils.file import Path, increment_path, list_files, save_json, save_pickle -from sahi.utils.import_utils import check_requirements +from sahi.utils.package_utils import check_requirements POSTPROCESS_NAME_TO_CLASS = { "GREEDYNMM": GreedyNMMPostprocess, diff --git a/sahi/scripts/coco2fiftyone.py b/sahi/scripts/coco2fiftyone.py index 0332242e6..065801501 100644 --- a/sahi/scripts/coco2fiftyone.py +++ b/sahi/scripts/coco2fiftyone.py @@ -1,8 +1,6 @@ import time from pathlib import Path -import fire - from sahi.utils.file import load_json @@ -13,6 +11,8 @@ def main( iou_thresh: float = 0.5, ): """ + Launch fiftyone app with coco dataset and coco results. + Args: image_dir (str): directory for coco images dataset_json_path (str): file path for the coco dataset json file @@ -75,7 +75,3 @@ def main( print(f"SAHI has successfully launched a Fiftyone app at http://localhost:{fo.config.default_app_port}") while 1: time.sleep(3) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/sahi/scripts/coco2yolo.py b/sahi/scripts/coco2yolo.py index fe75db679..69a5d7add 100644 --- a/sahi/scripts/coco2yolo.py +++ b/sahi/scripts/coco2yolo.py @@ -1,7 +1,5 @@ from __future__ import annotations -import fire - from sahi.utils.coco import Coco from sahi.utils.file import Path, increment_path @@ -15,17 +13,24 @@ def main( seed: int = 1, disable_symlink=False, ): - """ + """Convert COCO dataset to YOLO format. + Args: - images_dir (str): directory for coco images - dataset_json_path (str): file path for the coco json file to be converted - train_split (float or int): set the training split ratio - project (str): save results to project/name - name (str): save results to project/name" - seed (int): fix the seed for reproducibility - disable_symlink (bool): required in google colab env + image_dir: Directory containing COCO images + dataset_json_path: Path to the COCO JSON annotation file + train_split: Training split ratio (0.0 to 1.0) + project: Project directory for output + name: Experiment name for output subdirectory + seed: Random seed for reproducibility + disable_symlink: Disable symbolic links (required in some environments) """ + # Validate required parameters + if image_dir is None: + raise ValueError("image_dir is required") + if dataset_json_path is None: + raise ValueError("dataset_json_path is required") + # increment run save_dir = Path(increment_path(Path(project) / name, exist_ok=False)) # load coco dict @@ -42,7 +47,3 @@ def main( ) print(f"COCO to YOLO conversion results are successfully exported to {save_dir}") - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/sahi/scripts/coco_error_analysis.py b/sahi/scripts/coco_error_analysis.py index 906711d1f..2944ae38e 100644 --- a/sahi/scripts/coco_error_analysis.py +++ b/sahi/scripts/coco_error_analysis.py @@ -6,7 +6,6 @@ from multiprocessing import Pool from pathlib import Path -import fire import numpy as np from sahi.logger import logger @@ -452,6 +451,8 @@ def analyse( return_dict: bool = False, ) -> dict | None: """ + Perform COCO error analysis and export result plots to out_dir. + Args: dataset_json_path (str): file path for the coco dataset json file result_json_paths (str): file path for the coco result json file @@ -480,7 +481,3 @@ def analyse( ) if return_dict: return result - - -if __name__ == "__main__": - fire.Fire(analyse) diff --git a/sahi/scripts/coco_evaluation.py b/sahi/scripts/coco_evaluation.py index 2b08a8723..36a110484 100644 --- a/sahi/scripts/coco_evaluation.py +++ b/sahi/scripts/coco_evaluation.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Literal -import fire import numpy as np from terminaltables import AsciiTable @@ -361,7 +360,8 @@ def evaluate( areas: list[int] = [1024, 9216, 10000000000], return_dict: bool = False, ): - """ + """COCO evaluation entrypoint. + Args: dataset_json_path (str): file path for the coco dataset json file result_json_path (str): file path for the coco result json file @@ -373,6 +373,7 @@ def evaluate( areas (List[int]): area regions for coco evaluation calculations return_dict (bool): If True, returns a dict with 'eval_results' 'export_path' fields. """ + try: from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval @@ -396,7 +397,3 @@ def evaluate( ) if return_dict: return result - - -if __name__ == "__main__": - fire.Fire(evaluate) diff --git a/sahi/scripts/predict.py b/sahi/scripts/predict.py deleted file mode 100644 index 6d93960dc..000000000 --- a/sahi/scripts/predict.py +++ /dev/null @@ -1,11 +0,0 @@ -import fire - -from sahi.predict import predict - - -def main(): - fire.Fire(predict) - - -if __name__ == "__main__": - main() diff --git a/sahi/scripts/predict_fiftyone.py b/sahi/scripts/predict_fiftyone.py deleted file mode 100644 index ec3aa0d2e..000000000 --- a/sahi/scripts/predict_fiftyone.py +++ /dev/null @@ -1,11 +0,0 @@ -import fire - -from sahi.predict import predict_fiftyone - - -def main(): - fire.Fire(predict_fiftyone) - - -if __name__ == "__main__": - main() diff --git a/sahi/scripts/slice_coco.py b/sahi/scripts/slice_coco.py index e454ab925..8ef75a636 100644 --- a/sahi/scripts/slice_coco.py +++ b/sahi/scripts/slice_coco.py @@ -1,7 +1,5 @@ import os -import fire - from sahi.slicing import slice_coco from sahi.utils.file import Path, save_json @@ -16,6 +14,8 @@ def slicer( min_area_ratio: float = 0.1, ): """ + Slice COCO dataset images and annotations. + Args: image_dir (str): directory for coco images dataset_json_path (str): file path for the coco dataset json file @@ -61,7 +61,3 @@ def slicer( output_coco_annotation_file_path = os.path.join(output_dir, sliced_coco_name + ".json") save_json(coco_dict, output_coco_annotation_file_path) print(f"Sliced dataset for 'slice_size: {slice_size}' is exported to {output_dir}") - - -if __name__ == "__main__": - fire.Fire(slice) diff --git a/sahi/utils/cli_helper.py b/sahi/utils/cli_helper.py new file mode 100644 index 000000000..f955bcdd1 --- /dev/null +++ b/sahi/utils/cli_helper.py @@ -0,0 +1,89 @@ +import inspect + +import click + + +def _make_callback(obj): + """Return a callable suitable for click.Command: + - if obj is callable, call it with whatever click passes; + - otherwise print the object (e.g. version string). + + """ + if callable(obj): + + def _cb(*args, **kwargs): + return obj(*args, **kwargs) + else: + + def _cb(*args, **kwargs): + click.echo(str(obj)) + + return _cb + + +def _click_params_from_signature(func): + """Create a list of click.Parameter (Argument/Option) objects from a Python + callable's signature. This provides a lightweight automatic mapping so + CLI options are available without manually writing decorators. + + Rules (simple, pragmatic): + - positional parameters without default -> click.Argument (required) + - parameters with a default -> click.Option named --param-name + - bool defaults -> is_flag option + - list/tuple defaults -> multiple option + - skip *args/**kwargs and (self, cls) + - use annotation or default value to infer type when possible + """ + params = [] + sig = inspect.signature(func) + for name, p in sig.parameters.items(): + # skip common unrepresentable params + if name in ("self", "cls"): + continue + if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + # skip *args/**kwargs + continue + + opt_name = f"--{name}" + + if p.default is inspect._empty: + # required option (no default value) + param_type = None + if p.annotation is not inspect._empty and p.annotation in (int, float, str, bool): + param_type = p.annotation + else: + param_type = str + params.append(click.Option([opt_name], required=True, type=param_type, help=" ")) + else: + # boolean flags + if isinstance(p.default, bool): + params.append(click.Option([opt_name], is_flag=True, default=p.default, help=f" default={p.default}")) + # lists/tuples -> multiple + elif isinstance(p.default, (list, tuple)): + params.append(click.Option([opt_name], multiple=True, default=tuple(p.default), help=" multiple")) + else: + # infer type from annotation or default value + param_type = None + if p.annotation is not inspect._empty and p.annotation in (int, float, str, bool): + param_type = p.annotation + elif p.default is not None: + param_type = type(p.default) + else: + param_type = str + params.append( + click.Option([opt_name], default=p.default, type=param_type, help=f" default={p.default}") + ) + return params + + +def make_click_command(name, func): + """Build a click.Command for `func`, auto-generating params from signature if callable.""" + params = _click_params_from_signature(func) if callable(func) else [] + # use function docstring as help when available, but only the first line (summary) + help_text = None + if callable(func): + docstring = func.__doc__ if getattr(func, "__doc__", None) else None + if docstring: + # Extract only the first line of the docstring for cleaner CLI help + help_text = docstring.strip().split("\n")[0] + return click.Command(name, params=params, callback=_make_callback(func), help=help_text) diff --git a/sahi/utils/fiftyone.py b/sahi/utils/fiftyone.py index b3a469c19..700b1af96 100644 --- a/sahi/utils/fiftyone.py +++ b/sahi/utils/fiftyone.py @@ -2,7 +2,7 @@ import subprocess import sys -from sahi.utils.import_utils import is_available +from sahi.utils.package_utils import is_available if is_available("fiftyone"): # to fix https://github.com/voxel51/fiftyone/issues/845 diff --git a/sahi/utils/import_utils.py b/sahi/utils/import_utils.py deleted file mode 100644 index 0446a975e..000000000 --- a/sahi/utils/import_utils.py +++ /dev/null @@ -1,93 +0,0 @@ -import importlib.util - -from sahi.logger import logger - -# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py - - -def get_package_info(package_name: str, verbose: bool = True): - """Returns the package version as a string and the package name as a string.""" - _is_available = is_available(package_name) - - if _is_available: - try: - import importlib.metadata as _importlib_metadata - - _version = _importlib_metadata.version(package_name) - except (ModuleNotFoundError, AttributeError): - try: - _version = importlib.import_module(package_name).__version__ - except AttributeError: - _version = "unknown" - if verbose: - logger.pkg_info(f"{package_name} version {_version} is available.") - else: - _version = "N/A" - - return _is_available, _version - - -def print_environment_info(): - get_package_info("torch") - get_package_info("torchvision") - get_package_info("tensorflow") - get_package_info("tensorflow-hub") - get_package_info("ultralytics") - get_package_info("yolov5") - get_package_info("mmdet") - get_package_info("mmcv") - get_package_info("detectron2") - get_package_info("transformers") - get_package_info("timm") - get_package_info("fiftyone") - get_package_info("pillow") - get_package_info("opencv-python") - - -def is_available(module_name: str): - return importlib.util.find_spec(module_name) is not None - - -def check_requirements(package_names): - """Raise error if module is not installed.""" - missing_packages = [] - for package_name in package_names: - if importlib.util.find_spec(package_name) is None: - missing_packages.append(package_name) - if missing_packages: - raise ImportError(f"The following packages are required to use this module: {missing_packages}") - yield - - -def check_package_minimum_version(package_name: str, minimum_version: str, verbose=False): - """Raise error if module version is not compatible.""" - from packaging import version - - _is_available, _version = get_package_info(package_name, verbose=verbose) - if _is_available: - if _version == "unknown": - logger.warning( - f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible." - ) - else: - if version.parse(_version) < version.parse(minimum_version): - return False - return True - - -def ensure_package_minimum_version(package_name: str, minimum_version: str, verbose=False): - """Raise error if module version is not compatible.""" - from packaging import version - - _is_available, _version = get_package_info(package_name, verbose=verbose) - if _is_available: - if _version == "unknown": - logger.warning( - f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible." - ) - else: - if version.parse(_version) < version.parse(minimum_version): - raise ImportError( - f"Please upgrade {package_name} to version {minimum_version} or higher to use this module." - ) - yield diff --git a/sahi/utils/package_utils.py b/sahi/utils/package_utils.py new file mode 100644 index 000000000..35a9d9a21 --- /dev/null +++ b/sahi/utils/package_utils.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import importlib.util +import platform +from collections.abc import Generator +from typing import Any + +from sahi.logger import logger + +# adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py + + +# Mapping from package names to their import names +PACKAGE_TO_MODULE_MAP = { + "opencv-python": "cv2", + "opencv-python-headless": "cv2", + "pillow": "PIL", +} + + +def get_package_info(package_name: str) -> tuple[bool, str]: + """Returns the package version as a string and the package name as a string.""" + + if package_name not in PACKAGE_TO_MODULE_MAP: + module_name = package_name + else: + module_name = PACKAGE_TO_MODULE_MAP.get(package_name, package_name) + + _is_available = is_available(module_name) + + if _is_available: + try: + _version = importlib.import_module(module_name).__version__ + except (ModuleNotFoundError, AttributeError): + try: + _version = importlib.import_module(package_name).__version__ + except AttributeError: + _version = "unknown" + else: + _version = "N/A" + + logger.pkg_info(f"{package_name} version {_version} is installed.") + return _is_available, _version + + +def sys_info(): + logger.pkg_info("System Information:") + logger.pkg_info(f"Python version: {platform.python_version()}") + logger.pkg_info(f"Platform: {platform.platform().capitalize()}") + logger.pkg_info(f"Processor: {platform.processor().capitalize()}") + logger.pkg_info(f"Machine: {platform.machine().capitalize()}") + logger.pkg_info(f"System: {platform.system().capitalize()}") + logger.pkg_info(f"Release: {platform.release().capitalize()}") + logger.pkg_info(f"Version: {platform.version().capitalize()}") + logger.pkg_info(f"Architecture: {platform.architecture()[0].capitalize()}") + + +def print_environment_info() -> None: + sys_info() + logger.pkg_info("=== Package Information ===") + get_package_info("torch") + get_package_info("torchvision") + get_package_info("tensorflow") + get_package_info("tensorflow-hub") + get_package_info("ultralytics") + get_package_info("yolov5") + get_package_info("mmdet") + get_package_info("mmcv") + get_package_info("detectron2") + get_package_info("transformers") + get_package_info("timm") + get_package_info("fiftyone") + get_package_info("pillow") + get_package_info("opencv-python") + + +def is_available(module_name: str) -> bool: + return importlib.util.find_spec(module_name) is not None + + +def check_requirements(package_names: list[str]) -> Generator[None, Any, Any]: + """Raise error if module is not installed.""" + missing_packages = [] + for package_name in package_names: + if importlib.util.find_spec(package_name) is None: + missing_packages.append(package_name) + if missing_packages: + raise ImportError(f"The following packages are required to use this module: {missing_packages}") + yield + + +def _parse_version(version_str: str) -> tuple[int, ...]: + """Simple version parser that converts '1.2.3' to (1, 2, 3) for comparison.""" + try: + return tuple(int(x) for x in version_str.split(".")) + except (ValueError, AttributeError): + return (0,) # Default to 0.0 for unparseable versions + + +def check_package_minimum_version(package_name: str, minimum_version: str, raise_error: bool = False) -> bool: + """Check if module version meets minimum requirement. + + Args: + package_name: Name of the package to check + minimum_version: Minimum required version (e.g., '1.2.3') + raise_error: If True, raises ImportError when version is too low + + Returns: + True if version is compatible, False otherwise + + Raises: + ImportError: If raise_error=True and version is incompatible + """ + _is_available, _version = get_package_info(package_name) + + if not _is_available: + if raise_error: + raise ImportError(f"Package {package_name} is not installed.") + return False + + if _version == "unknown": + logger.warning( + f"Could not determine version of {package_name}. Assuming version {minimum_version} is compatible." + ) + return True + + if _version == "N/A": + if raise_error: + raise ImportError(f"Package {package_name} is not available.") + return False + + # Compare versions using simple tuple comparison + current_version = _parse_version(_version) + required_version = _parse_version(minimum_version) + + is_compatible = current_version >= required_version + + if not is_compatible and raise_error: + raise ImportError( + f"Please upgrade {package_name} to version {minimum_version} or higher. Current version: {_version}" + ) + + return is_compatible diff --git a/scripts/benchmark_combine.py b/scripts/benchmark_combine.py new file mode 100644 index 000000000..ba74bc863 --- /dev/null +++ b/scripts/benchmark_combine.py @@ -0,0 +1,84 @@ +"""Micro-benchmark for combine.py NMS/NMM functions. + +Runs simple timings for nms, nmm and greedy_nmm (and batched counterparts) with random boxes +and prints elapsed time. Designed to be run locally by developers. +""" + +import argparse +import random +import time + +import torch + +from sahi.postprocess.combine import ( + batched_greedy_nmm, + batched_nmm, + batched_nms, + greedy_nmm, + nmm, + nms, +) + + +def random_boxes(n, classes=1): + boxes = [] + for i in range(n): + x1 = random.random() * 10000 + y1 = random.random() * 10000 + w = random.random() * 50 + h = random.random() * 50 + x2 = x1 + max(1.0, w) + y2 = y1 + max(1.0, h) + score = random.random() + cid = random.randint(1, classes) + boxes.append([x1, y1, x2, y2, score, cid]) + return torch.tensor(boxes, dtype=torch.float32) + + +def time_fn(fn, arg, repeat=3): + times = [] + for _ in range(repeat): + t0 = time.time() + fn(arg) + times.append(time.time() - t0) + return min(times), sum(times) / len(times) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--sizes", nargs="*", type=int, default=[100, 500, 1000]) + parser.add_argument("--classes", type=int, default=5) + args = parser.parse_args() + + print("Benchmarking NMS/NMM implementations") + for n in args.sizes: + print(f"\nInput size: {n}") + data = random_boxes(n, classes=args.classes) + + # nms (class-agnostic) + mn, avg = time_fn(nms, data) + print(f"nms: min={mn:.4f}s avg={avg:.4f}s") + + # batched_nms (class-aware) + mn, avg = time_fn(batched_nms, data) + print(f"batched_nms: min={mn:.4f}s avg={avg:.4f}s") + + # nmm + mn, avg = time_fn(nmm, data) + print(f"nmm: min={mn:.4f}s avg={avg:.4f}s") + + # batched_nmm + mn, avg = time_fn(batched_nmm, data) + print(f"batched_nmm: min={mn:.4f}s avg={avg:.4f}s") + + # greedy_nmm + mn, avg = time_fn(greedy_nmm, data) + print(f"greedy_nmm: min={mn:.4f}s avg={avg:.4f}s") + + # batched_greedy_nmm + mn, avg = time_fn(batched_greedy_nmm, data) + print(f"batched_greedy_nmm: min={mn:.4f}s avg={avg:.4f}s") + + +if __name__ == "__main__": + main() diff --git a/scripts/utils.py b/scripts/utils.py deleted file mode 100644 index be33b06e9..000000000 --- a/scripts/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -import shutil - - -def print_console_centered(text: str, fill_char="="): - """Print text centered in console with fill characters.""" - w, _ = shutil.get_terminal_size((80, 20)) - print(f" {text} ".center(w, fill_char)) diff --git a/tests/test_detectron2_model.py b/tests/test_detectron2_model.py index 9e6ef7992..d5a8f2c62 100644 --- a/tests/test_detectron2_model.py +++ b/tests/test_detectron2_model.py @@ -2,7 +2,7 @@ from sahi.prediction import ObjectPrediction from sahi.utils.cv import read_image from sahi.utils.detectron2 import Detectron2TestConstants -from sahi.utils.import_utils import get_package_info +from sahi.utils.package_utils import get_package_info MODEL_DEVICE = "cpu" CONFIDENCE_THRESHOLD = 0.5 @@ -11,7 +11,7 @@ # note that detectron2 binaries are available only for linux # TODO: This test is currently not running as torch version is pinned to 1.13 -torch_version = get_package_info("torch", verbose=False)[1] +torch_version = get_package_info("torch")[1] if "1.10." in torch_version: class TestDetectron2DetectionModel: