Skip to content

Commit 4c5f94b

Browse files
authored
[MNT] Isolate scikit-learn dependency with checks (#674)
`scikit-learn` is a soft dependency - this PR fully isolates `scikit-learn` with dependency checks. It is not removed from the core dependency set, even though the dependency is entirely localized in the `CovarianceShrinkage` object, since it seems to be called frequently from the rest of the code base.
1 parent e88bf31 commit 4c5f94b

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

pypfopt/risk_models.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy as np
2727
import pandas as pd
28+
from skbase.utils.dependencies import _check_soft_dependencies
2829

2930
from .expected_returns import returns_from_prices
3031

@@ -298,11 +299,14 @@ def min_cov_determinant(
298299
warnings.warn("data is not in a dataframe", RuntimeWarning)
299300
prices = pd.DataFrame(prices)
300301

301-
# Extra dependency
302-
try:
303-
import sklearn.covariance
304-
except (ModuleNotFoundError, ImportError):
305-
raise ImportError("Please install scikit-learn via pip or poetry")
302+
if not _check_soft_dependencies(["scikit-learn"], severity="none"):
303+
raise ImportError(
304+
"scikit-learn is required to use min_cov_determinant. "
305+
"Please ensure that scikit-learn is installed in your environment,"
306+
" e.g via pip install scikit-learn"
307+
)
308+
309+
from sklearn.covariance import fast_mcd
306310

307311
assets = prices.columns
308312

@@ -312,7 +316,7 @@ def min_cov_determinant(
312316
X = returns_from_prices(prices, log_returns)
313317
# X = np.nan_to_num(X.values)
314318
X = X.dropna().values
315-
raw_cov_array = sklearn.covariance.fast_mcd(X, random_state=random_state)[1]
319+
raw_cov_array = fast_mcd(X, random_state=random_state)[1]
316320
cov = pd.DataFrame(raw_cov_array, index=assets, columns=assets) * frequency
317321
return fix_nonpositive_semidefinite(cov, kwargs.get("fix_method", "spectral"))
318322

@@ -379,13 +383,16 @@ def __init__(self, prices, returns_data=False, frequency=252, log_returns=False)
379383
:param log_returns: whether to compute using log returns
380384
:type log_returns: bool, defaults to False
381385
"""
382-
# Optional import
383-
try:
384-
from sklearn import covariance
386+
if not _check_soft_dependencies(["scikit-learn"], severity="none"):
387+
raise ImportError(
388+
"scikit-learn is required to use CovarianceShrinkage. "
389+
"Please ensure that scikit-learn is installed in your environment,"
390+
" e.g via pip install scikit-learn"
391+
)
392+
393+
from sklearn import covariance
385394

386-
self.covariance = covariance
387-
except (ModuleNotFoundError, ImportError): # pragma: no cover
388-
raise ImportError("Please install scikit-learn via pip or poetry")
395+
self.covariance = covariance
389396

390397
if not isinstance(prices, pd.DataFrame):
391398
warnings.warn("data is not in a dataframe", RuntimeWarning)

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ dependencies = [
3535
"cvxpy>=1.1.19",
3636
"numpy>=1.26.0",
3737
"pandas>=0.19",
38+
"scikit-base<0.14.0",
3839
"scikit-learn>=0.24.1",
3940
"scipy>=1.3.0",
40-
"scikit-base<0.14.0",
4141
]
4242

4343
[project.optional-dependencies]
@@ -54,7 +54,6 @@ dependencies = [
5454
all_extras = [
5555
"matplotlib>=3.2.0",
5656
"plotly>=5.0.0,<6",
57-
"scikit-learn>=0.24.1",
5857
"ecos>=2.0.14,<2.1",
5958
"plotly>=5.0.0,<7",
6059
"cvxopt; python_version < '3.14'",

0 commit comments

Comments
 (0)