-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathglobal_learner.py
More file actions
139 lines (108 loc) · 4.59 KB
/
global_learner.py
File metadata and controls
139 lines (108 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from sklearn import __version__ as sklearn_version
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, clone, is_classifier, is_regressor
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import _check_sample_weight, check_is_fitted
def parse_version(version):
    """
    Extract the leading ``(major, minor)`` release numbers from a version string.

    Only the first two dot-separated components are considered. If a component is not a plain
    integer (e.g. the ``"3rc1"`` in ``"1.3rc1"``), its leading digits are used instead, so
    pre-release or development suffixes glued to a component do not break parsing.

    Parameters
    ----------
    version: str
        Version string such as ``"1.6.1"``, ``"1.6.dev0"`` or ``"1.3rc1"``.

    Returns
    -------
    tuple of int
        The numeric values of up to the first two components, e.g. ``(1, 6)``.
        A version without a dot yields a 1-tuple.

    Raises
    ------
    ValueError
        If one of the first two components is neither an integer nor starts with a digit.
    """
    release = []
    for component in version.split(".")[:2]:
        try:
            # Fast path: identical to the plain ``int`` conversion for well-formed components.
            number = int(component)
        except ValueError:
            # Fall back to the leading ASCII digits and drop any suffix such as "rc1" or "b2".
            n_digits = len(component) - len(component.lstrip("0123456789"))
            if n_digits == 0:
                raise ValueError(
                    f"Cannot parse version component {component!r} of version {version!r}."
                ) from None
            number = int(component[:n_digits])
        release.append(number)
    return tuple(release)
# TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
# Feature flag: ``validate_data`` is only imported when the installed sklearn is at least 1.6;
# code that validates inputs has to branch on this flag before using it.
sklearn_supports_validation = (1, 6) <= parse_version(sklearn_version)
if sklearn_supports_validation:
    from sklearn.utils.validation import validate_data
class GlobalRegressor(RegressorMixin, BaseEstimator):
    """
    Regressor wrapper that always performs a global (unweighted) fit.

    Any ``sample_weight`` passed to ``fit()`` is validated but never forwarded to the wrapped regressor.

    Parameters
    ----------
    base_estimator: regressor implementing ``fit()`` and ``predict()``
        Regressor that is cloned and fitted in ``fit()`` and used for ``predict()``.
    """

    def __init__(self, base_estimator):
        self.base_estimator = base_estimator

    def fit(self, X, y, sample_weight=None):
        """
        Fit a clone of ``base_estimator`` on ``X`` and ``y`` without using ``sample_weight``.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Training data.

        y: array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values.

        sample_weight: array-like of shape (n_samples,).
            Individual weights for each sample. Checked against ``X`` but otherwise ignored.

        Returns
        -------
        self
            The fitted wrapper.

        Raises
        ------
        ValueError
            If ``base_estimator`` is not a regressor.
        """
        learner = self.base_estimator
        if not is_regressor(learner):
            raise ValueError(f"base_estimator must be a regressor. Got {learner.__class__.__name__} instead.")

        # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
        X, y = validate_data(self, X, y) if sklearn_supports_validation else self._validate_data(X, y)

        # Called for validation only; the returned weights are deliberately discarded.
        _check_sample_weight(sample_weight, X)

        fitted = clone(learner)
        self._fitted_learner = fitted
        fitted.fit(X, y)
        return self

    def predict(self, X):
        """
        Return the predictions of the fitted clone of ``base_estimator``.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Samples.
        """
        check_is_fitted(self)
        fitted = self._fitted_learner
        return fitted.predict(X)
class GlobalClassifier(ClassifierMixin, BaseEstimator):
    """
    Classifier wrapper that always performs a global (unweighted) fit.

    Any ``sample_weight`` passed to ``fit()`` is validated but never forwarded to the wrapped classifier.

    Parameters
    ----------
    base_estimator: classifier implementing ``fit()`` and ``predict_proba()``
        Classifier that is cloned and fitted in ``fit()`` and used for ``predict()`` and ``predict_proba()``.
    """

    def __init__(self, base_estimator):
        self.base_estimator = base_estimator

    def fit(self, X, y, sample_weight=None):
        """
        Fit a clone of ``base_estimator`` on ``X`` and ``y`` without using ``sample_weight``.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Training data.

        y: array-like of shape (n_samples,) or (n_samples, n_targets)
            Target classes.

        sample_weight: array-like of shape (n_samples,).
            Individual weights for each sample. Checked against ``X`` but otherwise ignored.

        Returns
        -------
        self
            The fitted wrapper.

        Raises
        ------
        ValueError
            If ``base_estimator`` is not a classifier.
        """
        learner = self.base_estimator
        if not is_classifier(learner):
            raise ValueError(f"base_estimator must be a classifier. Got {learner.__class__.__name__} instead.")

        # TODO(0.11) can be removed if the sklearn dependency is bumped to 1.6.0
        X, y = validate_data(self, X, y) if sklearn_supports_validation else self._validate_data(X, y)

        # Called for validation only; the returned weights are deliberately discarded.
        _check_sample_weight(sample_weight, X)

        # Labels are taken from the validated targets before the clone is fitted.
        self.classes_ = unique_labels(y)

        fitted = clone(learner)
        self._fitted_learner = fitted
        fitted.fit(X, y)
        return self

    def predict(self, X):
        """
        Return the class predictions of the fitted clone of ``base_estimator``.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Samples.
        """
        check_is_fitted(self)
        fitted = self._fitted_learner
        return fitted.predict(X)

    def predict_proba(self, X):
        """
        Return the probability estimates of the fitted clone of ``base_estimator``.

        The output is passed through unchanged, so the column order is whatever the wrapped
        classifier's ``predict_proba()`` returns.

        Parameters
        ----------
        X: array-like of shape (n_samples, n_features)
            Samples to be scored.
        """
        check_is_fitted(self)
        fitted = self._fitted_learner
        return fitted.predict_proba(X)