Skip to content

Commit 85853d2

Browse files
Fix estimator type (#36)
* loosening numpy and pandas versions
* fixes to build requirements
* simplifying setup.py, logic in pyproject.toml
* fixing _estimator_type flags
* feat: Add scikit-learn estimator type tags

  Adds the `__sklearn_tags__` method to the `sklearn_sm` and `sklearn_selected` wrappers. This allows scikit-learn to correctly identify the estimator type (regressor or classifier) based on the statsmodels model. This change enables the use of scikit-learn's cross-validation and model selection tools with these wrappers. Tests have been added to verify that OLS and GLM Binomial models are correctly identified.
* removing redundant setup.py
* unused pkg_resources import
1 parent 707a2fe commit 85853d2

File tree

6 files changed

+110
-73
lines changed

6 files changed

+110
-73
lines changed

ISLP/models/generic_selector.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
import scipy as sp
2929

3030
from sklearn.metrics import get_scorer
31-
from sklearn.base import (clone, MetaEstimatorMixin)
31+
from sklearn.base import (clone,
32+
MetaEstimatorMixin,
33+
is_classifier,
34+
is_regressor)
3235
from sklearn.model_selection import cross_val_score
3336
from joblib import Parallel, delayed
3437

@@ -149,13 +152,13 @@ def __init__(self,
149152
self.scoring = scoring
150153

151154
if scoring is None:
152-
if self.est_._estimator_type == 'classifier':
155+
if is_classifier(self.est_):
153156
scoring = 'accuracy'
154-
elif self.est_._estimator_type == 'regressor':
157+
elif is_regressor(self.est_):
155158
scoring = 'r2'
156159
else:
157-
raise AttributeError('Estimator must '
158-
'be a Classifier or Regressor.')
160+
scoring = None
161+
159162
if isinstance(scoring, str):
160163
self.scorer = get_scorer(scoring)
161164
else:
@@ -166,7 +169,7 @@ def __init__(self,
166169
# don't mess with this unless testing
167170
self._TESTING_INTERRUPT_MODE = False
168171

169-
def fit(self, X, y, groups=None, **params):
172+
def fit(self, X, y, groups=None, **fit_params):
170173
"""Perform feature selection and learn model from training data.
171174
172175
Parameters
@@ -183,7 +186,7 @@ def fit(self, X, y, groups=None, **params):
183186
groups: array-like, with shape (n_samples,), optional
184187
Group labels for the samples used while splitting the dataset into
185188
train/test set. Passed to the fit method of the cross-validator.
186-
params: various, optional
189+
fit_params: various, optional
187190
Additional parameters that are being passed to the estimator.
188191
For example, `sample_weights=weights`.
189192
@@ -218,7 +221,7 @@ def fit(self, X, y, groups=None, **params):
218221
groups=groups,
219222
cv=self.cv,
220223
pre_dispatch=self.pre_dispatch,
221-
**params)
224+
**fit_params)
222225

223226
# keep a running track of the best state
224227

@@ -242,7 +245,7 @@ def fit(self, X, y, groups=None, **params):
242245
X,
243246
y,
244247
groups=groups,
245-
**params)
248+
**fit_params)
246249
iteration += 1
247250
cur, best_, self.finished_ = self.update_results_check(results_,
248251
self.path_,
@@ -287,7 +290,7 @@ def fit_transform(self,
287290
X,
288291
y,
289292
groups=None,
290-
**params):
293+
**fit_params):
291294
"""Fit to training data then reduce X to its most important features.
292295
293296
Parameters
@@ -304,7 +307,7 @@ def fit_transform(self,
304307
groups: array-like, with shape (n_samples,), optional
305308
Group labels for the samples used while splitting the dataset into
306309
train/test set. Passed to the fit method of the cross-validator.
307-
params: various, optional
310+
fit_params: various, optional
308311
Additional parameters that are being passed to the estimator.
309312
For example, `sample_weights=weights`.
310313
@@ -313,7 +316,7 @@ def fit_transform(self,
313316
Reduced feature subset of X, shape={n_samples, k_features}
314317
315318
"""
316-
self.fit(X, y, groups=groups, **params)
319+
self.fit(X, y, groups=groups, **fit_params)
317320
return self.transform(X)
318321

319322
def get_metric_dict(self, confidence_interval=0.95):
@@ -368,7 +371,7 @@ def _batch(self,
368371
X,
369372
y,
370373
groups=None,
371-
**params):
374+
**fit_params):
372375

373376
results = []
374377

@@ -388,7 +391,7 @@ def _batch(self,
388391
groups=groups,
389392
cv=self.cv,
390393
pre_dispatch=self.pre_dispatch,
391-
**params)
394+
**fit_params)
392395
for state in candidates)
393396

394397
for state, scores in work:
@@ -484,8 +487,11 @@ def _calc_score(estimator,
484487
groups=None,
485488
cv=None,
486489
pre_dispatch='2*n_jobs',
487-
**params):
490+
**fit_params):
488491

492+
if scorer is None:
493+
scorer = lambda estimator, X, y: estimator.score(X, y)
494+
489495
X_state = build_submodel(X, state)
490496

491497
if cv:
@@ -497,11 +503,11 @@ def _calc_score(estimator,
497503
scoring=scorer,
498504
n_jobs=1,
499505
pre_dispatch=pre_dispatch,
500-
params=params)
506+
fit_params=fit_params)
501507
else:
502508
estimator.fit(X_state,
503509
y,
504-
**params)
510+
**fit_params)
505511
scores = np.array([scorer(estimator,
506512
X_state,
507513
y)])

ISLP/models/sklearn_wrap.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,17 @@ def __init__(self,
4949
self.model_type = model_type
5050
self.model_spec = model_spec
5151
self.model_args = model_args
52-
52+
53+
def __sklearn_tags__(self):
    """Report the scikit-learn estimator type of the wrapped model.

    Starts from the tags supplied by the parent class and sets
    ``estimator_type`` to ``'regressor'`` when ``model_type`` is
    ``sm.OLS``, or to ``'classifier'`` when ``model_type`` is a
    ``sm.GLM`` subclass whose ``model_args['family']`` is a
    ``sm.families.Binomial`` instance. Any other configuration keeps
    the inherited value.

    Returns
    -------
    tags : object
        The object returned by ``super().__sklearn_tags__()``,
        possibly with ``estimator_type`` overwritten.
    """
    tags = super().__sklearn_tags__()

    if self.model_type is sm.OLS:
        tags.estimator_type = 'regressor'
        return tags

    # A missing (or None) `model_args` simply means "no family given";
    # a single `.get` covers both the membership and the type check.
    family = (self.model_args or {}).get('family')

    # `issubclass` raises TypeError for non-class objects; guard it so
    # that asking for tags never crashes the estimator.
    is_glm = (isinstance(self.model_type, type) and
              issubclass(self.model_type, sm.GLM))

    if is_glm and isinstance(family, sm.families.Binomial):
        tags.estimator_type = 'classifier'
    return tags
62+
5363
def fit(self, X, y):
5464
"""
5565
Fit a statsmodel model
@@ -171,6 +181,9 @@ def __init__(self,
171181
self.cv = cv
172182
self.scoring = scoring
173183

184+
def __sklearn_tags__(self):
    """Return the estimator tags exactly as supplied by the parent class."""
    return super().__sklearn_tags__()
174187

175188
def fit(self, X, y):
176189
"""

ISLP/torch/imdb.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import torch
1313
from torch.utils.data import TensorDataset
1414
from scipy.sparse import load_npz
15-
from pkg_resources import resource_filename
1615
from pickle import load as load_pickle
1716
import urllib
1817

pyproject.toml

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "ISLP"
33
dependencies = ["numpy>=1.7.1",
44
"scipy>=0.9",
5-
"pandas>=0.20",
5+
"pandas>=1.5",
66
"lxml", # pandas needs this for html
77
"scikit-learn>=1.2",
88
"joblib",
@@ -15,7 +15,7 @@ dependencies = ["numpy>=1.7.1",
1515
]
1616
description = "Library for ISLP labs"
1717
readme = "README.md"
18-
requires-python = ">=3.9"
18+
requires-python = ">=3.10"
1919
license = {file = "LICENSE"}
2020
keywords = []
2121
authors = [
@@ -38,6 +38,23 @@ classifiers = ["Development Status :: 3 - Alpha",
3838
]
3939
dynamic = ["version"]
4040

41+
[tool.setuptools]
42+
packages = [
43+
"ISLP",
44+
"ISLP.models",
45+
"ISLP.bart",
46+
"ISLP.torch",
47+
"ISLP.data"
48+
]
49+
include-package-data = true
50+
51+
[tool.setuptools.package-data]
52+
ISLP = ["data/*.csv", "data/*.npy", "data/*.data"]
53+
54+
[tool.setuptools.dynamic]
55+
version = {attr = "ISLP.__version__"} # Assuming ISLP.__version__ holds your version
56+
57+
4158
[project.urls] # Optional
4259
"Homepage" = "https://github.com/intro-stat-learning/ISLP"
4360
"Bug Reports" = "https://github.com/intro-stat-learning/ISLP/issues"
@@ -51,8 +68,14 @@ doc = ['Sphinx>=3.0']
5168
[build-system]
5269
requires = ["setuptools>=42",
5370
"wheel",
54-
"versioneer[toml]",
55-
"Sphinx>=1.0"
71+
"Sphinx>=1.0",
72+
"numpy",
73+
"pandas",
74+
"scipy",
75+
"scikit-learn",
76+
"joblib",
77+
"statsmodels",
78+
"versioneer[toml]"
5679
]
5780
build-backend = "setuptools.build_meta"
5881

setup.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/models/test_sklearn_wrap.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
2+
import numpy as np
3+
import pandas as pd
4+
import statsmodels.api as sm
5+
from sklearn.base import is_classifier, is_regressor
6+
import pytest
7+
8+
from ISLP.models.sklearn_wrap import sklearn_sm, sklearn_selected
9+
from ISLP.models.model_spec import ModelSpec
10+
from ISLP.models.strategy import min_max
11+
12+
@pytest.fixture
def model_setup():
    """Build a small, reproducible data set and selection strategy.

    Returns
    -------
    tuple
        ``(X, y, model_spec, strategy)`` where ``X`` is a 10-row frame
        with columns ``X1``, ``X2``, ``X3``; ``y`` is a 0/1 series of
        the same length; ``model_spec`` is a ``ModelSpec`` fitted on
        ``X``; and ``strategy`` is the ``min_max`` strategy built from
        it with ``min_terms=1`` and ``max_terms=2``.
    """
    # Seeded local generator: the fixture yields the same data on every
    # run instead of depending on the global NumPy random state.
    rng = np.random.default_rng(0)
    n_rows = 10
    X = pd.DataFrame({name: rng.random(n_rows)
                      for name in ('X1', 'X2', 'X3')})
    y = pd.Series(rng.integers(0, 2, n_rows))  # binary response, for classifier
    model_spec_dummy = ModelSpec(['X1', 'X2', 'X3']).fit(X)
    min_max_strategy_dummy = min_max(model_spec_dummy,
                                     min_terms=1,
                                     max_terms=2)
    return X, y, model_spec_dummy, min_max_strategy_dummy
19+
20+
def test_OLS_is_regressor():
    """An OLS wrapper is tagged, and recognised by sklearn, as a regressor."""
    ols_wrapper = sklearn_sm(sm.OLS)
    reported_type = ols_wrapper.__sklearn_tags__().estimator_type
    assert reported_type == 'regressor'
    assert is_regressor(ols_wrapper)
24+
25+
def test_GLM_binomial_is_classifier():
    """A GLM wrapper with a Binomial family is tagged as a classifier."""
    binomial_args = {'family': sm.families.Binomial()}
    glm_wrapper = sklearn_sm(sm.GLM, model_args=binomial_args)
    reported_type = glm_wrapper.__sklearn_tags__().estimator_type
    assert reported_type == 'classifier'
    assert is_classifier(glm_wrapper)
29+
30+
def test_GLM_binomial_probit_is_classifier():
    """A Binomial GLM with a probit link is still tagged as a classifier."""
    probit_family = sm.families.Binomial(link=sm.families.links.Probit())
    glm_wrapper = sklearn_sm(sm.GLM, model_args={'family': probit_family})
    reported_type = glm_wrapper.__sklearn_tags__().estimator_type
    assert reported_type == 'classifier'
    assert is_classifier(glm_wrapper)
34+
35+
36+
def test_selected_OLS_is_regressor(model_setup):
    """A selection wrapper around OLS is tagged as a regressor."""
    # Only the strategy is needed; the data and model spec are unused here.
    _, _, _, strategy = model_setup
    selected = sklearn_selected(sm.OLS, strategy=strategy)
    reported_type = selected.__sklearn_tags__().estimator_type
    assert reported_type == 'regressor'
    assert is_regressor(selected)
41+
42+
def test_selected_GLM_binomial_is_classifier(model_setup):
    """A selection wrapper around a Binomial GLM is tagged as a classifier."""
    # Only the strategy is needed; the data and model spec are unused here.
    _, _, _, strategy = model_setup
    binomial_args = {'family': sm.families.Binomial()}
    selected = sklearn_selected(sm.GLM,
                                strategy=strategy,
                                model_args=binomial_args)
    reported_type = selected.__sklearn_tags__().estimator_type
    assert reported_type == 'classifier'
    assert is_classifier(selected)

0 commit comments

Comments
 (0)