Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) (+ the Migration Guide),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.1.0] -

### Changed

- (Potentially Breaking): We now explicitly check that classes follow the "rules of tpcp" and simply forward or set parameters
without modifying them. Before, we just "hoped" that this was the case.
If you had classes that did not follow these rules, you will now get an error.

### Added

- The `DatasetSplitter` now auto-selects a proper splitter based on the inputs and attempts to validate whether the passed
splitter supports the features needed.


## [2.0.0] - 2024-10-24

### Added
Expand Down
65 changes: 61 additions & 4 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def create_test_class(

# Set the signature to conform to the expected conventions
sig = signature(test_class.__init__)
sig = sig.replace(parameters=(Parameter(k, Parameter.KEYWORD_ONLY, default=v) for k, v in params.items()))
sig = sig.replace(
parameters=(
Parameter("self", Parameter.POSITIONAL_OR_KEYWORD),
*(Parameter(k, Parameter.KEYWORD_ONLY, default=v) for k, v in user_set_params.items()),
)
)
test_class.__init__.__signature__ = sig
class_dict = {**class_dict, "__init__": test_class.__init__}
# Recreate the class with the correct init
Expand Down Expand Up @@ -116,7 +121,7 @@ def test_get_results(example_test_class_after_action):
def test_get_parameter(example_test_class_after_action):
    """``get_params`` must expose private (fixed) params merged with the user-set ones."""
    instance, test_parameters = example_test_class_after_action

    # Note: a stale pre-change assertion (comparing against ``params`` only) was
    # left over from a merge and has been removed here.
    assert instance.get_params() == {**test_parameters["params"], **test_parameters["private_params"]}


def test_get_action_params(example_test_class_after_action):
Expand Down Expand Up @@ -589,12 +594,12 @@ def __init__(self, foo=cf("foo")):
self.foo = foo

class Bar(Foo):
def __init__(self, foo=cf("foo"), bar=cf("bar")):
def __init__(self, foo=cf("foo2"), bar="bar"):
super().__init__(foo)
self.bar = bar

bar = Bar()
assert bar.get_params() == {"foo": "foo", "bar": "bar"}
assert bar.get_params() == {"foo": "foo2", "bar": "bar"}


def test_validate_all_parent_params_implemented():
Expand Down Expand Up @@ -630,3 +635,55 @@ def __init__(self, a, b):
Child(2, 1).get_params()

assert "Parent" in str(e.value)


def test_validate_not_setting_all_parameters():
    """Instantiation must fail when an init accepts a parameter but never stores it."""

    class Test(Algorithm):
        def __init__(self, a, b):
            self.a = a

    # `b` is accepted by the init but never assigned -> the base class must complain
    # and name the missing parameter in the error message.
    with pytest.raises(RuntimeError, match="`b`"):
        Test(a=1, b=2)


def test_validate_modifing_parameters():
    """Instantiation must fail when an init stores a modified parameter value."""

    class Test(Algorithm):
        def __init__(self, a):
            self.a = a + 1

    # The stored attribute differs from the passed value -> error naming `a`.
    with pytest.raises(RuntimeError, match="`a`"):
        Test(a=1)


def test_validation_triggered_with_child_class():
    """The parameter checks must also catch violations in a parent init reached via super()."""

    class Parent(Algorithm):
        def __init__(self, a):
            self.a = a + 1

    class Child(Parent):
        def __init__(self, a, b):
            self.b = b
            super().__init__(a)

    # The modification happens in Parent.__init__, but must still be detected
    # when constructing the child.
    with pytest.raises(RuntimeError, match="`a`"):
        Child(a=1, b=2)


def test_validation_triggered_with_child_class_without_init():
    """The parameter checks must fire for subclasses that simply inherit the parent init."""

    class Parent(Algorithm):
        def __init__(self, a):
            self.a = a + 1

    class Child(Parent):
        pass

    # Child defines no __init__ of its own, so Parent's faulty init runs and must raise.
    with pytest.raises(RuntimeError, match="`a`"):
        Child(a=1)
40 changes: 34 additions & 6 deletions tests/test_pipelines/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold

from tests.test_pipelines.conftest import (
DummyDataset,
Expand Down Expand Up @@ -309,13 +309,17 @@ def test_normal_k_fold(self):
assert train_expected.tolist() == train.tolist()
assert test_expected.tolist() == test.tolist()

def test_normal_k_fold_with_groupby_ignored(self):
@pytest.mark.parametrize("provide", ["groupby", "stratify"])
def test_normal_k_fold_with_groupby_and_stratified_ignored(self, provide):
ds = DummyGroupedDataset()
splitter = DatasetSplitter(base_splitter=KFold(n_splits=5), groupby="v1")
# This should be identical to just calling the splitter directly
splits_expected = list(KFold(n_splits=5).split(ds))
paras = {provide: "v1"}
expected_string = "grouping" if provide == "groupby" else "stratification"
splitter = DatasetSplitter(base_splitter=KFold(n_splits=5), **paras)
# This should be identical to just calling the splitter directly, but should give a warning
with pytest.warns(UserWarning, match=f"sklearn splitters that do support {expected_string}."):
splits = list(splitter.split(ds))

splits = list(splitter.split(ds))
splits_expected = list(KFold(n_splits=5).split(ds))

for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
assert train_expected.tolist() == train.tolist()
Expand Down Expand Up @@ -344,3 +348,27 @@ def test_normal_stratified_k_fold(self):
for (train_expected, test_expected), (train, test) in zip(splits_expected, splits):
assert train_expected.tolist() == train.tolist()
assert test_expected.tolist() == test.tolist()

def test_auto_selection_group_k_fold(self):
splitter = DatasetSplitter(base_splitter=None, groupby="v1")
inner_splitter = splitter._get_splitter()
assert isinstance(inner_splitter, GroupKFold)
assert inner_splitter.n_splits == 5

def test_auto_selection_stratified_k_fold(self):
splitter = DatasetSplitter(base_splitter=None, stratify="v1")
inner_splitter = splitter._get_splitter()
assert isinstance(inner_splitter, StratifiedKFold)
assert inner_splitter.n_splits == 5

def test_auto_selection_k_fold(self):
splitter = DatasetSplitter(base_splitter=None)
inner_splitter = splitter._get_splitter()
assert isinstance(inner_splitter, KFold)
assert inner_splitter.n_splits == 5

def test_auto_selection_group_k_fold_with_stratified(self):
splitter = DatasetSplitter(base_splitter=None, groupby="v1", stratify="v2")
inner_splitter = splitter._get_splitter()
assert isinstance(inner_splitter, StratifiedGroupKFold)
assert inner_splitter.n_splits == 5
68 changes: 62 additions & 6 deletions tpcp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ def _get_init_defaults(cls: type[_BaseTpcpObject]) -> dict[str, inspect.Paramete
return defaults


_init_implementation_hint = (
"All classes in using tpcp must set all parameters passed to the init as parameter "
"without any modification.\n\n"
"A typical `__init__` method should look like this:\n\n"
">>> def __init__(self, para_1, para_2):\n"
">>> self.para_1 = para_1\n"
">>> self.para_2 = para_2\n"
"\nAny additional logic to validate parameters or set complicated defaults should be "
"done at the start of the action method, not in the init.\n"
"The reason for this is that you can modify parameters after the init has run. "
"In case relevant logic is contained in the `__init__` method, this logic would be "
"skipped, if parameters are modified after the init using `set_params`."
)


def _is_latest_init_in_mro(cls, instance):
"""Check if cls.__init__ is the first __init__ in the MRO of instance."""
mro = type(instance).__mro__

# Find the first class in MRO that defines its own __init__
first_with_init = next((c for c in mro if "__init__" in c.__dict__), None)

return first_with_init is cls


def _replace_defaults_wrapper(
cls: type[_BaseTpcpObjectT], old_init: Callable[Concatenate[_BaseTpcpObjectT, P], T]
) -> Callable[Concatenate[_BaseTpcpObjectT, P], T]:
Expand All @@ -116,18 +141,48 @@ def _replace_defaults_wrapper(
# super().__init__ is called before all parameters of the child object are set.
# This way, the param checks only concern the parameters of the current class.
params = get_param_names(cls)
defaults = {k: v.default for k, v in _get_init_defaults(cls).items()}

@wraps(old_init)
def new_init(self: _BaseTpcpObjectT, *args: P.args, **kwargs: P.kwargs) -> None:
# call the old init.
old_init(self, *args, **kwargs)
# We check if we have been run via super or if the init was called directly.
# This is required, as we can perform all the checks that follow only after the full "callstacK" of inits
# has been run.
# This means, we only want to run the checks, if the init was called directly, i.e. not via super.
if not _is_latest_init_in_mro(cls, self):
return

# Check if any of the initial values has a "default parameter flag".
# If yes we replace it with a clone (in case of a tpcp object) or a deepcopy in case of other objects.
# After the old init ran, we can check that the parameters have been set correctly.
# Tpcp classes should not modify any parameters in their init and set them to exactly the values using
# attributes with the same name.
# We get all params passed to the init, by manually binding them to the signature of the old init.
# We check that all parameters are set and the values are identical
#
# At the same time, we check if any parameter is wrapped by a base factory.
# And replace the value accordingly.
# This is handled by the factory `get_value` method.
passed_params = {**defaults, **inspect.signature(old_init).bind(self, *args, **kwargs).arguments}
sentinel = object() # We use a unique object to check if the value was set. We can not use NOTHING here, as
# NOTHING can be a valid input parameter value.
for p in params:
if isinstance(val := getattr(self, p), BaseFactory):
setattr(self, p, val.get_value())
passed_value = passed_params[p]
set_value = getattr(self, p, sentinel)
if set_value is sentinel:
raise RuntimeError(
f"The class `{cls.__name__}` did not set the parameter `{p}` in its init.\n\n"
f"{_init_implementation_hint}"
)
if set_value is not passed_value:
raise RuntimeError(
f"The class `{cls.__name__}` or one of its parent classes modified the parameter `{p}` in its "
"init. "
f"The value available as {cls.__name__}.{p} is not identical to the value passed to the init with "
f"the same name.\n\n{_init_implementation_hint}"
)
if isinstance(set_value, BaseFactory):
setattr(self, p, set_value.get_value())

# This is just for introspection, in case we want to know if we have a modified init.
new_init.__tpcp_wrapped__ = True
Expand Down Expand Up @@ -195,7 +250,7 @@ def _retry_eval_with_missing_locals(
return val


def _custom_get_type_hints(cls: type[_BaseTpcpObject]) -> dict[str, Any]:
def _custom_get_type_hints(cls: type) -> dict[str, Any]:
"""Extract type hints while avoiding issues with forward references.

We automatically skip all douple-underscore methods.
Expand Down Expand Up @@ -627,7 +682,8 @@ def _has_dangerous_mutable_default(fields: dict[str, inspect.Parameter], cls: ty
"\n"
"Note, that we do not check for all cases of mutable objects. "
f"At the moment, we check only for {_get_dangerous_mutable_types()}. "
"To learn more about this topic, check TODO: LINK."
"To learn more about this topic, check "
"https://tpcp.readthedocs.io/en/latest/guides/general_concepts.html#mutable-defaults ."
)


Expand Down
66 changes: 58 additions & 8 deletions tpcp/validate/_cross_val_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numbers
import warnings
from collections.abc import Iterator
from typing import Optional, Union

from sklearn.model_selection import BaseCrossValidator, check_cv
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedGroupKFold, StratifiedKFold, check_cv

from tpcp import BaseTpcpObject, Dataset

Expand All @@ -12,11 +14,11 @@ class DatasetSplitter(BaseTpcpObject):
This wrapper can be used instead of a sklearn-style splitter with all methods that support a ``cv`` parameter.
Whenever you want to do complicated cv-logic (like grouping or stratification's), this wrapper is the way to go.

.. warning:: We don't validate if the selected ``base_splitter`` does anything useful with the provided
``groupby`` and ``stratify`` information.
This wrapper just ensures, that the information is correctly extracted from the dataset and passed to the
``split`` method of the ``base_splitter``.
So if you are using a normal ``KFold`` splitter, the ``groupby`` and ``stratify`` arguments will have no effect.
You can either select your own base splitter, or we will select from KFold, StratifiedKFold, GroupKFold, or
StratifiedGroupKFold, depending on the provided ``groupby`` and ``stratify`` parameters.

.. warning:: If you use a custom splitter, that does not support grouping or stratification, these parameters might
be silently ignored.

Parameters
----------
Expand All @@ -38,7 +40,11 @@ class DatasetSplitter(BaseTpcpObject):
This will passed to the base splitter as the ``y`` parameter, acting as "mock" target labels, as sklearn only
support stratification on classification outcome targets.
It is up to the base splitter to decide what to do with the generated labels.

ignore_potentially_invalid_splitter_warning
We are trying to detect if the provided splitter supports grouping and stratification.
If they are not supported, but you provided groupby or stratify columns, we will warn you.
Note, that this warning is not a perfect check, as it is not possible to detect all cases.
If you know what you are doing, and you want to disable this warning, set this parameter to True.
"""

def __init__(
Expand All @@ -47,13 +53,57 @@ def __init__(
*,
groupby: Optional[Union[str, list[str]]] = None,
stratify: Optional[Union[str, list[str]]] = None,
ignore_potentially_invalid_splitter_warning: bool = False,
):
self.base_splitter = base_splitter
self.stratify = stratify
self.groupby = groupby
self.ignore_potentially_invalid_splitter_warning = ignore_potentially_invalid_splitter_warning

def _get_splitter(self):
    """Resolve ``base_splitter`` into a concrete sklearn-compatible splitter.

    If ``base_splitter`` is ``None`` (treated as 5 splits) or an integer, a
    suitable KFold variant is auto-selected based on whether ``groupby`` and/or
    ``stratify`` are set.
    Unless ``ignore_potentially_invalid_splitter_warning`` is True, a heuristic
    name-based check then warns if the resolved splitter looks like it will
    silently ignore the requested grouping/stratification.

    Fixes in this revision: removed a stale pre-change ``return check_cv(...)``
    line left over from a merge, removed trailing commas that turned ``msg``
    into a 1-tuple (the warning then rendered the tuple repr instead of the
    message text), and added the missing space before "For a list".
    """
    cv = self.base_splitter
    cv = 5 if cv is None else cv
    if isinstance(cv, numbers.Integral):
        # Auto-select the sklearn KFold variant matching the requested features.
        if self.groupby is not None and self.stratify is not None:
            cv = StratifiedGroupKFold(n_splits=cv)
        elif self.groupby is not None:
            cv = GroupKFold(n_splits=cv)
        elif self.stratify is not None:
            cv = StratifiedKFold(n_splits=cv)
    cv = check_cv(cv, y=None, classifier=True)

    if self.ignore_potentially_invalid_splitter_warning:
        return cv

    # The checks below are heuristic (class-name based), as sklearn offers no
    # reliable API to query group/stratification support.
    msg = None
    if self.groupby and "Group" not in cv.__class__.__name__:
        msg = (
            "You specified groupby columns for the splitter, but it looks like you did not select any of the "
            "typical sklearn splitters that do support grouping. "
            "Splitters that don't support grouping will silently ignore the grouping information."
        )
    if self.stratify and "Stratified" not in cv.__class__.__name__:
        msg = (
            "You specified stratify columns for the splitter, but it looks like you did not select any of the "
            "typical sklearn splitters that do support stratification. "
            "Splitters that don't support stratification will silently ignore the stratification information."
        )
    if msg is not None:
        warnings.warn(
            (
                f"{msg}"
                "\nTo fix this issue pass a splitter that supports the required functionality as `base_splitter`. "
                "For a list of available splitters see "
                "https://scikit-learn.org/stable/api/sklearn.model_selection.html "
                "\nIf you provided a custom splitter, and you know what you are doing, you can disable this "
                "warning, by setting the `ignore_potentially_invalid_splitter_warning=True` when creating the "
                "DatasetSplitter object."
            ),
            UserWarning,
            stacklevel=2,
        )
    return cv

def _get_labels(self, dataset: Dataset, labels: Union[None, str, list[str]]):
if labels:
Expand Down
Loading