Skip to content

Commit 3d51e7b

Browse files
mkalimeriFBruzzesi
andauthored
feat: add handle_zero option in ZeroInflatedRegressor estimator (#714)
* Update test_zero_inflated_regressor.py Implementation of solution for issue 480: [FEATURE] Run with given regressor instead of raising warning in ZeroInflatedRegressor * Update test_zero_inflated_regressor.py Implementation of unitests for issue 480: [FEATURE] Run with given regressor instead of raising warning in ZeroInflatedRegressor * Update test_zero_inflated_regressor.py Fixed error in unit tests * Apply suggestions from code review Accepted edits on messages/comments as suggested Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> * Unit test update to match the updated ValueError output ValueError text for failure if handle_zero value is not one of ['ignore', 'error'] was updated. The relevant unittest had to be updated too * Unittest that asserts that if handle_zero='ignore' and all outputs are 0, no exception is thrown If all train set outputs are 0 and handle_zero = 'ignore', the regressor should fit the values as is and no exception should be thrown * move handle_zero to init --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Co-authored-by: FBruzzesi <francesco.bruzzesi.93@gmail.com>
1 parent 1c66894 commit 3d51e7b

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

sklego/meta/zero_inflated_regressor.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):
1313
1414
`ZeroInflatedRegressor` consists of a classifier and a regressor.
1515
16-
- The classifier's task is to find of if the target is zero or not.
17-
- The regressor's task is to output a (usually positive) prediction whenever the classifier indicates that the
16+
- The classifier's task is to find if the target is zero or not.
17+
- The regressor's task is to output a (usually positive) prediction whenever the classifier indicates that
1818
there should be a non-zero prediction.
1919
2020
The regressor is only trained on examples where the target is non-zero, which makes it easier for it to focus.
@@ -29,6 +29,11 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):
2929
regressor : scikit-learn compatible regressor
3030
A regressor for predicting the target. Its prediction is only used if `classifier` says that the output is
3131
non-zero.
32+
handle_zero : Literal["error", "ignore"], default="error"
33+
How to behave in the case that all train set output consists of zero values only.
34+
35+
- `handle_zero = 'error'`: will raise a `ValueError` (default).
36+
- `handle_zero = 'ignore'`: will continue to train the regressor on the entire dataset.
3237
3338
Attributes
3439
----------
@@ -63,9 +68,10 @@ class ZeroInflatedRegressor(BaseEstimator, RegressorMixin, MetaEstimatorMixin):
6368

6469
_required_parameters = ["classifier", "regressor"]
6570

66-
def __init__(self, classifier, regressor) -> None:
71+
def __init__(self, classifier, regressor, handle_zero="error") -> None:
6772
self.classifier = classifier
6873
self.regressor = regressor
74+
self.handle_zero = handle_zero
6975

7076
def fit(self, X, y, sample_weight=None):
7177
"""Fit the underlying classifier and regressor using `X` and `y` as training data. The regressor is only trained
@@ -88,7 +94,9 @@ def fit(self, X, y, sample_weight=None):
8894
Raises
8995
------
9096
ValueError
91-
If `classifier` is not a classifier or `regressor` is not a regressor.
97+
If `classifier` is not a classifier
98+
If `regressor` is not a regressor
99+
If all train target entirely consists of zeros and `handle_zero="error"`
92100
"""
93101
X, y = check_X_y(X, y)
94102
self._check_n_features(X, reset=True)
@@ -98,6 +106,10 @@ def fit(self, X, y, sample_weight=None):
98106
)
99107
if not is_regressor(self.regressor):
100108
raise ValueError(f"`regressor` has to be a regressor. Received instance of {type(self.regressor)} instead.")
109+
if self.handle_zero not in {"ignore", "error"}:
110+
raise ValueError(
111+
f"`handle_zero` has to be one of {'ignore', 'error'}. Received '{self.handle_zero}' instead."
112+
)
101113

102114
sample_weight = _check_sample_weight(sample_weight, X)
103115
try:
@@ -112,9 +124,14 @@ def fit(self, X, y, sample_weight=None):
112124
logging.warning("Classifier ignores sample_weight.")
113125
self.classifier_.fit(X, y != 0)
114126

115-
non_zero_indices = np.where(y != 0)[0]
127+
indices_for_training = np.where(y != 0)[0] # these are the non-zero indices
128+
if (self.handle_zero == "ignore") & (
129+
indices_for_training.size == 0
130+
): # if we choose to ignore that all train set output is 0
131+
logging.warning("Regressor will be training on `y` consisting of zero values only.")
132+
indices_for_training = np.where(y == 0)[0] # use the whole train set
116133

117-
if non_zero_indices.size > 0:
134+
if indices_for_training.size > 0:
118135
try:
119136
check_is_fitted(self.regressor)
120137
self.regressor_ = self.regressor
@@ -123,20 +140,21 @@ def fit(self, X, y, sample_weight=None):
123140

124141
if "sample_weight" in signature(self.regressor_.fit).parameters:
125142
self.regressor_.fit(
126-
X[non_zero_indices],
127-
y[non_zero_indices],
128-
sample_weight=sample_weight[non_zero_indices] if sample_weight is not None else None,
143+
X[indices_for_training],
144+
y[indices_for_training],
145+
sample_weight=sample_weight[indices_for_training] if sample_weight is not None else None,
129146
)
130147
else:
131148
logging.warning("Regressor ignores sample_weight.")
132149
self.regressor_.fit(
133-
X[non_zero_indices],
134-
y[non_zero_indices],
150+
X[indices_for_training],
151+
y[indices_for_training],
135152
)
136153
else:
137154
raise ValueError(
138155
"""The predicted training labels are all zero, making the regressor obsolete. Change the classifier
139-
or use a plain regressor instead."""
156+
or use a plain regressor instead. Alternatively, you can choose to ignore that predicted labels are
157+
all zero by setting flag handle_zero = 'ignore'"""
140158
)
141159

142160
return self

tests/test_meta/test_zero_inflated_regressor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,24 @@ def test_zero_inflated_with_sample_weights_example(classifier, regressor, perfor
7171
assert zir_score > performance
7272

7373

74+
def test_zero_inflated_with_handle_zero_ignore_example():
75+
"""Test that if handle_zero='ignore' and all y are 0, no Exception will be thrown"""
76+
77+
np.random.seed(0)
78+
size = 1_000
79+
X = np.random.randn(size, 4)
80+
y = np.zeros(size) # all outputs are 0
81+
82+
zir = ZeroInflatedRegressor(
83+
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
84+
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
85+
handle_zero="ignore",
86+
).fit(X, y)
87+
88+
# The predicted values should all be 0
89+
assert (zir.predict(X) == np.zeros(size)).all()
90+
91+
7492
def test_wrong_estimators_exceptions():
7593
X = np.array([[0.0]])
7694
y = np.array([0.0])
@@ -83,6 +101,27 @@ def test_wrong_estimators_exceptions():
83101
zir = ZeroInflatedRegressor(ExtraTreesClassifier(), ExtraTreesClassifier())
84102
zir.fit(X, y)
85103

104+
with pytest.raises(
105+
ValueError, match="`handle_zero` has to be one of \('ignore', 'error'\). Received 'ignor' instead."
106+
):
107+
zir = ZeroInflatedRegressor(
108+
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
109+
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
110+
handle_zero="ignor",
111+
)
112+
zir.fit(X, y)
113+
114+
error_text = """The predicted training labels are all zero, making the regressor obsolete\. Change the classifier
115+
or use a plain regressor instead\. Alternatively, you can choose to ignore that predicted labels are
116+
all zero by setting flag handle_zero = 'ignore'"""
117+
118+
with pytest.raises(ValueError, match=error_text):
119+
zir = ZeroInflatedRegressor(
120+
classifier=ExtraTreesClassifier(max_depth=20, random_state=0, n_jobs=-1),
121+
regressor=ExtraTreesRegressor(max_depth=20, random_state=0, n_jobs=-1),
122+
)
123+
zir.fit(X, y) # default is handle_zero = 'error'
124+
86125

87126
def approx_lte(x, y):
88127
return ((x <= y) | np.isclose(x, y)).all()

0 commit comments

Comments
 (0)