diff --git a/src/hyperactive/_registry/__init__.py b/src/hyperactive/registry/__init__.py similarity index 50% rename from src/hyperactive/_registry/__init__.py rename to src/hyperactive/registry/__init__.py index 722edf23..4216d6d0 100644 --- a/src/hyperactive/_registry/__init__.py +++ b/src/hyperactive/registry/__init__.py @@ -1,5 +1,5 @@ """Hyperactive registry.""" -from hyperactive._registry._lookup import all_objects +from hyperactive.registry._lookup import all_objects __all__ = ["all_objects"] diff --git a/src/hyperactive/_registry/_lookup.py b/src/hyperactive/registry/_lookup.py similarity index 94% rename from src/hyperactive/_registry/_lookup.py rename to src/hyperactive/registry/_lookup.py index 63f23b07..2410468f 100644 --- a/src/hyperactive/_registry/_lookup.py +++ b/src/hyperactive/registry/_lookup.py @@ -34,6 +34,16 @@ def all_objects( Not included are: the base classes themselves, classes defined in test modules. + To filter by object type, use the ``object_types`` parameter: + + * ``"optimizer"``: optimizers, e.g., the Hill Climbing optimizer, + or the Particle Swarm optimizer + * ``"experiment"``: experiments, e.g., the Ackley function, or an ``sklearn`` + cross-validation experiment + * if ``None``, no filter is applied and all objects are returned. + + To filter by tag, use the ``filter_tags`` parameter. + Parameters ---------- object_types: str, list of str, optional (default=None) @@ -137,7 +147,7 @@ def all_objects( Examples -------- - >>> from hyperactive._registry import all_objects + >>> from hyperactive.registry import all_objects >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) # doctest: +SKIP diff --git a/src/hyperactive/registry/tests/__init__.py b/src/hyperactive/registry/tests/__init__.py new file mode 100644 index 00000000..96b5c8c8 --- /dev/null +++ b/src/hyperactive/registry/tests/__init__.py @@ -0,0 +1 @@ +"""Registry tests.""" diff --git a/src/hyperactive/registry/tests/test_lookup.py b/src/hyperactive/registry/tests/test_lookup.py new file mode 100644 index 00000000..44cb6241 --- /dev/null +++ b/src/hyperactive/registry/tests/test_lookup.py @@ -0,0 +1,52 @@ +# copyright: hyperactive developers, BSD-3-Clause License (see LICENSE file) +"""Testing of registry lookup functionality.""" + +# based on the sktime module of same name + +__author__ = ["fkiraly"] + +import pytest + +from hyperactive.registry import all_objects + +object_types = ["optimizer", "experiment"] + + +def _to_list(obj): + """Put obj in list if it is not a list.""" + if not isinstance(obj, list): + return [obj] + else: + return obj.copy() + + +@pytest.mark.parametrize("return_names", [True, False]) +@pytest.mark.parametrize("object_type", object_types) +def test_all_objects_by_scitype(object_type, return_names): + """Check that all_objects return argument has correct type.""" + objects = all_objects( + object_types=object_type, + return_names=return_names, + ) + + assert isinstance(objects, list) + # there should be at least one object returned + assert len(objects) > 0 + + # checks return type specification (see docstring) + for obj in objects: + if return_names: + assert isinstance(obj, tuple) and len(obj) == 2 + name = obj[0] + obj_cls = obj[1] + assert isinstance(name, str) + assert hasattr(obj_cls, "__name__") + assert name == obj_cls.__name__ + else: + obj_cls = obj + + assert hasattr(obj_cls, "get_tags") + type_from_obj = obj_cls.get_class_tag("object_type") + type_from_obj = _to_list(type_from_obj) + + assert object_type in type_from_obj diff --git a/src/hyperactive/tests/test_all_objects.py b/src/hyperactive/tests/test_all_objects.py index 986c0086..606ff43a 100644 --- a/src/hyperactive/tests/test_all_objects.py +++ b/src/hyperactive/tests/test_all_objects.py @@ -7,7 +7,7 @@ from skbase.testing import QuickTester as _QuickTester from skbase.testing import TestAllObjects as _TestAllObjects -from hyperactive._registry import all_objects +from hyperactive.registry import all_objects from hyperactive.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS from hyperactive.tests._doctest import run_doctest