-
Notifications
You must be signed in to change notification settings - Fork 63
[ENH] extend sktime ForecastingOptCV with broadcasting options and returned parameters
#205
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0341319
74ed4a0
707f3a2
89d7e56
4fb6fa2
ea9d52a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,20 @@ | ||
| # copyright: hyperactive developers, MIT License (see LICENSE file) | ||
|
|
||
| import time | ||
|
|
||
| import numpy as np | ||
| from skbase.utils.dependencies import _check_soft_dependencies | ||
|
|
||
| if _check_soft_dependencies("sktime", severity="none"): | ||
| _HAS_SKTIME = _check_soft_dependencies("sktime", severity="none") | ||
|
|
||
| if _HAS_SKTIME: | ||
| from sktime.datatypes import mtype_to_scitype | ||
| from sktime.forecasting.base._delegate import _DelegatedForecaster | ||
| else: | ||
| from skbase.base import BaseEstimator as _DelegatedForecaster | ||
|
|
||
| mtype_to_scitype = None | ||
|
|
||
| from hyperactive.experiment.integrations.sktime_forecasting import ( | ||
| SktimeForecastingExperiment, | ||
| ) | ||
|
|
@@ -151,6 +158,15 @@ class ForecastingOptCV(_DelegatedForecaster): | |
| - "logger_name": str, default="ray"; name of the logger to use. | ||
| - "mute_warnings": bool, default=False; if True, suppresses warnings | ||
|
|
||
| tune_by_instance : bool, optional (default=False) | ||
| Whether to tune parameters separately for each time series instance when | ||
| panel or hierarchical data is passed. Mirrors ``ForecastingGridSearchCV`` | ||
| semantics by delegating broadcasting to sktime's vectorization logic. | ||
| tune_by_variable : bool, optional (default=False) | ||
| Whether to tune parameters per variable for strictly multivariate series. | ||
| When enabled, only univariate targets are accepted and internal | ||
| broadcasting is handled by sktime. | ||
|
|
||
| Example | ||
| ------- | ||
| Any available tuning engine from hyperactive can be used, for example: | ||
|
|
@@ -215,6 +231,8 @@ def __init__( | |
| cv_X=None, | ||
| backend=None, | ||
| backend_params=None, | ||
| tune_by_instance=False, | ||
| tune_by_variable=False, | ||
| ): | ||
| self.forecaster = forecaster | ||
| self.optimizer = optimizer | ||
|
|
@@ -227,8 +245,20 @@ def __init__( | |
| self.cv_X = cv_X | ||
| self.backend = backend | ||
| self.backend_params = backend_params | ||
| self.tune_by_instance = tune_by_instance | ||
| self.tune_by_variable = tune_by_variable | ||
| super().__init__() | ||
|
|
||
| if _HAS_SKTIME: | ||
| self._set_delegated_tags(delegate=self.forecaster) | ||
| tags_to_clone = ["y_inner_mtype", "X_inner_mtype"] | ||
| self.clone_tags(self.forecaster, tags_to_clone) | ||
| self._extend_to_all_scitypes("y_inner_mtype") | ||
| self._extend_to_all_scitypes("X_inner_mtype") | ||
|
|
||
| if self.tune_by_variable: | ||
| self.set_tags(**{"scitype:y": "univariate"}) | ||
|
|
||
| def _fit(self, y, X, fh): | ||
| """Fit to training data. | ||
|
|
||
|
|
@@ -250,6 +280,16 @@ def _fit(self, y, X, fh): | |
| forecaster = self.forecaster.clone() | ||
|
|
||
| scoring = check_scoring(self.scoring, obj=self) | ||
| self.scorer_ = scoring | ||
| get_n_splits = getattr(self.cv, "get_n_splits", None) | ||
| if callable(get_n_splits): | ||
| try: | ||
| self.n_splits_ = get_n_splits(y) | ||
| except TypeError: | ||
| # fallback for splitters that expect no args | ||
| self.n_splits_ = get_n_splits() | ||
| else: | ||
| self.n_splits_ = None | ||
| # scoring_name = f"test_{scoring.name}" | ||
|
|
||
| experiment = SktimeForecastingExperiment( | ||
|
|
@@ -270,14 +310,54 @@ def _fit(self, y, X, fh): | |
| best_params = optimizer.solve() | ||
|
|
||
| self.best_params_ = best_params | ||
| self.best_index_ = getattr(optimizer, "best_index_", None) | ||
| raw_best_score, best_metadata = experiment.evaluate(best_params) | ||
| self.best_score_ = float(raw_best_score) | ||
| results_table = best_metadata.get("results") if best_metadata else None | ||
| if results_table is not None: | ||
| try: | ||
| self.cv_results_ = results_table.copy() | ||
| except AttributeError: | ||
| self.cv_results_ = results_table | ||
| else: | ||
| self.cv_results_ = None | ||
| self.best_forecaster_ = forecaster.set_params(**best_params) | ||
|
|
||
| # Refit model with best parameters. | ||
| if self.refit: | ||
| refit_start = time.perf_counter() | ||
| self.best_forecaster_.fit(y=y, X=X, fh=fh) | ||
| self.refit_time_ = time.perf_counter() - refit_start | ||
| else: | ||
| self.refit_time_ = 0.0 | ||
|
|
||
| return self | ||
|
|
||
| def _extend_to_all_scitypes(self, tagname): | ||
| """Ensure mtypes for all scitypes are present in tag ``tagname``.""" | ||
| if not _HAS_SKTIME: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed since |
||
| return | ||
|
|
||
| tagval = self.get_tag(tagname) | ||
| if not isinstance(tagval, list): | ||
| tagval = [tagval] | ||
| scitypes = mtype_to_scitype(tagval, return_unique=True) | ||
|
|
||
| if "Series" not in scitypes: | ||
| tagval = tagval + ["pd.DataFrame"] | ||
| elif "pd.Series" in tagval and "pd.DataFrame" not in tagval: | ||
| tagval = ["pd.DataFrame"] + tagval | ||
|
|
||
| if "Panel" not in scitypes: | ||
| tagval = tagval + ["pd-multiindex"] | ||
| if "Hierarchical" not in scitypes: | ||
| tagval = tagval + ["pd_multiindex_hier"] | ||
|
|
||
| if self.tune_by_instance: | ||
| tagval = [x for x in tagval if mtype_to_scitype(x) == "Series"] | ||
|
|
||
| self.set_tags(**{tagname: tagval}) | ||
|
|
||
| def _predict(self, fh, X): | ||
| """Forecast time series at future horizon. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,18 @@ | ||
| """Integration tests for sktime tuners.""" | ||
| # copyright: hyperactive developers, MIT License (see LICENSE file) | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| from skbase.utils.dependencies import _check_soft_dependencies | ||
|
|
||
| if _check_soft_dependencies("sktime", severity="none"): | ||
| from sktime.datasets import load_airline | ||
| from sktime.forecasting.naive import NaiveForecaster | ||
| from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError | ||
| from sktime.split import ExpandingWindowSplitter | ||
|
|
||
| from hyperactive.integrations.sktime import ForecastingOptCV, TSCOptCV | ||
| from hyperactive.opt import GridSearchSk | ||
|
|
||
| EST_TO_TEST = [ForecastingOptCV, TSCOptCV] | ||
| else: | ||
|
|
@@ -20,3 +27,47 @@ def test_sktime_estimator(estimator): | |
| check_estimator(estimator, raise_exceptions=True) | ||
| # The above line collects all API conformance tests in sktime and runs them. | ||
| # It will raise an error if the estimator is not API conformant. | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not EST_TO_TEST, reason="sktime not installed") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would directly condition on |
||
| def test_forecasting_opt_cv_sets_attributes(): | ||
| """ForecastingOptCV exposes useful attributes after fitting.""" | ||
| fh = [1, 2] | ||
| y = load_airline().iloc[:36] | ||
| cv = ExpandingWindowSplitter(initial_window=24, step_length=6, fh=fh) | ||
| optimizer = GridSearchSk(param_grid={"strategy": ["last", "mean"]}) | ||
|
|
||
| tuner = ForecastingOptCV( | ||
| forecaster=NaiveForecaster(), | ||
| optimizer=optimizer, | ||
| cv=cv, | ||
| scoring=MeanAbsolutePercentageError(symmetric=True), | ||
| backend="None", | ||
| ) | ||
|
|
||
| tuner.fit(y=y, fh=fh) | ||
|
|
||
| assert tuner.scorer_.name == "MeanAbsolutePercentageError" | ||
| assert tuner.n_splits_ == cv.get_n_splits(y) | ||
| assert tuner.refit_time_ >= 0 | ||
|
|
||
| metric_col = "test_" + tuner.scorer_.name | ||
| assert metric_col in tuner.cv_results_.columns | ||
| assert np.isclose(tuner.best_score_, tuner.cv_results_[metric_col].mean()) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not EST_TO_TEST, reason="sktime not installed") | ||
| def test_forecasting_opt_cv_tune_by_flags(): | ||
| """Tune-by flags should adjust estimator tags.""" | ||
| tuner = ForecastingOptCV( | ||
| forecaster=NaiveForecaster(), | ||
| optimizer=GridSearchSk(param_grid={"strategy": ["last"]}), | ||
| cv=ExpandingWindowSplitter(initial_window=5, step_length=1, fh=[1]), | ||
| tune_by_instance=True, | ||
| tune_by_variable=True, | ||
| ) | ||
|
|
||
| assert tuner.get_tag("scitype:y") == "univariate" | ||
| y_mtypes = tuner.get_tag("y_inner_mtype") | ||
| assert "pd-multiindex" not in y_mtypes | ||
| assert "pd_multiindex_hier" not in y_mtypes | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move the import to
_extend_to_all_scitypes.