Skip to content

Commit 1aa6807

Browse files
add fill_value to RelativeFeatures for when dividing by zero (#663)
* add fill_value param to init() * create test_error_if_fill_value_not_permitted * fix bug * add fill_value param to relevant tests * create test_transformer_incl_fill_values_when_dividing_by_zero(). need to add logic to _div() * fix bug in test_transformer_fill_values_when_division_by_zero() * update docstring * fix bug * update unit tests * rebase branch * fix style bug * do not accept a string for fill_value. change tests accordingly. * delete fill_value when default value, None * update unit tests * fix style errors * sorted and tidied imports * reformat and added fill value to all div methods * expanded tests * fixed bug when using replace * add test to match error message --------- Co-authored-by: Morgan-Sell <[email protected]>
1 parent 2c71e8c commit 1aa6807

File tree

2 files changed

+118
-29
lines changed

2 files changed

+118
-29
lines changed

feature_engine/creation/relative_features.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ class RelativeFeatures(BaseCreation):
7777
one or more of the following strings: 'add', 'mul', 'sub', 'div', 'truediv',
7878
'floordiv', 'mod', 'pow'.
7979
80+
fill_value: int, float, default=None
81+
When dividing by zero, this value replaces the result, which would otherwise be infinite or undefined (NaN). If None,
82+
then an error will be raised when dividing by zero.
83+
8084
{missing_values}
8185
8286
{drop_original}
@@ -128,6 +132,7 @@ def __init__(
128132
variables: List[Union[str, int]],
129133
reference: List[Union[str, int]],
130134
func: List[str],
135+
fill_value: Union[int, float, None] = None,
131136
missing_values: str = "ignore",
132137
drop_original: bool = False,
133138
) -> None:
@@ -163,10 +168,16 @@ def __init__(
163168
"Supported functions are {}. ".format(", ".join(_PERMITTED_FUNCTIONS))
164169
)
165170

171+
if fill_value is not None and not isinstance(fill_value, (float, int)):
172+
raise ValueError(
173+
"fill_value must be a float, integer or None. "
174+
f"Got {fill_value} instead."
175+
)
166176
super().__init__(missing_values, drop_original)
167177
self.variables = variables
168178
self.reference = reference
169179
self.func = func
180+
self.fill_value = fill_value
170181

171182
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
172183
"""
@@ -213,17 +224,6 @@ def _sub(self, X):
213224
X[varname] = X[self.variables].sub(X[reference], axis=0)
214225
return X
215226

216-
def _div(self, X):
217-
for reference in self.reference:
218-
if (X[reference] == 0).any():
219-
raise ValueError(
220-
"Some of the reference variables contain 0 as values. Check and "
221-
"remove those before using this transformer."
222-
)
223-
varname = [f"{var}_div_{reference}" for var in self.variables]
224-
X[varname] = X[self.variables].div(X[reference], axis=0)
225-
return X
226-
227227
def _add(self, X):
228228
for reference in self.reference:
229229
varname = [f"{var}_add_{reference}" for var in self.variables]
@@ -236,38 +236,60 @@ def _mul(self, X):
236236
X[varname] = X[self.variables].mul(X[reference], axis=0)
237237
return X
238238

239-
def _truediv(self, X):
239+
def _div(self, X):
240+
for reference in self.reference:
241+
zeros_ix, contains_zero = self._find_zeroes_in_reference(X, reference)
240242

243+
if self.fill_value is None and contains_zero:
244+
self._raise_error_when_zero_in_denominator()
245+
246+
varname = [f"{var}_div_{reference}" for var in self.variables]
247+
X[varname] = X[self.variables].div(X[reference], axis=0)
248+
249+
if contains_zero:
250+
X.loc[zeros_ix, varname] = self.fill_value
251+
return X
252+
253+
def _truediv(self, X):
241254
for reference in self.reference:
242-
if (X[reference] == 0).any():
243-
raise ValueError(
244-
"Some of the reference variables contain 0 as values. Check and "
245-
"remove those before using this transformer."
246-
)
255+
zeros_ix, contains_zero = self._find_zeroes_in_reference(X, reference)
256+
257+
if self.fill_value is None and contains_zero:
258+
self._raise_error_when_zero_in_denominator()
259+
247260
varname = [f"{var}_truediv_{reference}" for var in self.variables]
248261
X[varname] = X[self.variables].truediv(X[reference], axis=0)
262+
263+
if contains_zero:
264+
X.loc[zeros_ix, varname] = self.fill_value
249265
return X
250266

251267
def _floordiv(self, X):
252268
for reference in self.reference:
253-
if (X[reference] == 0).any():
254-
raise ValueError(
255-
"Some of the reference variables contain 0 as values. Check and "
256-
"remove those before using this transformer."
257-
)
269+
zeros_ix, contains_zero = self._find_zeroes_in_reference(X, reference)
270+
271+
if self.fill_value is None and contains_zero:
272+
self._raise_error_when_zero_in_denominator()
273+
258274
varname = [f"{var}_floordiv_{reference}" for var in self.variables]
259275
X[varname] = X[self.variables].floordiv(X[reference], axis=0)
276+
277+
if contains_zero:
278+
X.loc[zeros_ix, varname] = self.fill_value
260279
return X
261280

262281
def _mod(self, X):
263282
for reference in self.reference:
264-
if (X[reference] == 0).any():
265-
raise ValueError(
266-
"Some of the reference variables contain 0 as values. Check and "
267-
"remove those before using this transformer."
268-
)
283+
zeros_ix, contains_zero = self._find_zeroes_in_reference(X, reference)
284+
285+
if self.fill_value is None and contains_zero:
286+
self._raise_error_when_zero_in_denominator()
287+
269288
varname = [f"{var}_mod_{reference}" for var in self.variables]
270289
X[varname] = X[self.variables].mod(X[reference], axis=0)
290+
291+
if contains_zero:
292+
X.loc[zeros_ix, varname] = self.fill_value
271293
return X
272294

273295
def _pow(self, X):
@@ -276,6 +298,18 @@ def _pow(self, X):
276298
X[varname] = X[self.variables].pow(X[reference], axis=0)
277299
return X
278300

301+
def _raise_error_when_zero_in_denominator(self):
302+
raise ValueError(
303+
"Some of the reference variables contain zeroes. Division by zero "
304+
"does not exist. Replace zeros before using this transformer for division "
305+
"or set `fill_value` to a number."
306+
)
307+
308+
def _find_zeroes_in_reference(self, X, var):
309+
zero_ix = X[var] == 0
310+
zero_bool = (zero_ix).any()
311+
return zero_ix, zero_bool
312+
279313
def _get_new_features_name(self) -> List:
280314
"""Return names of the created features."""
281315

tests/test_creation/test_relative_features.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ def test_error_if_func_not_supported(_func):
5353
)
5454

5555

56+
@pytest.mark.parametrize("_fill_value", [(2, 3.3), ["test"], "python"])
57+
def test_error_if_fill_value_not_permitted(_fill_value):
58+
with pytest.raises(ValueError):
59+
RelativeFeatures(
60+
variables=["Age"],
61+
reference=["Marks"],
62+
func=["sub", "div", "add", "mul"],
63+
fill_value=_fill_value,
64+
)
65+
66+
5667
def test_error_when_drop_original_not_bool():
5768
for drop_original in ["True", [True]]:
5869
with pytest.raises(ValueError):
@@ -318,7 +329,7 @@ def test_when_df_cols_are_integers(df_vartypes):
318329

319330

320331
@pytest.mark.parametrize("_func", [["div"], ["truediv"], ["floordiv"], ["mod"]])
321-
def test_error_when_division_by_zero(_func, df_vartypes):
332+
def test_error_when_division_by_zero_and_fill_value_is_none(_func, df_vartypes):
322333

323334
df_zero = df_vartypes.copy()
324335
df_zero.loc[1, "Marks"] = 0
@@ -329,9 +340,53 @@ def test_error_when_division_by_zero(_func, df_vartypes):
329340
func=_func,
330341
)
331342
transformer.fit(df_vartypes)
332-
with pytest.raises(ValueError):
343+
344+
with pytest.raises(ValueError) as record:
333345
transformer.transform(df_zero)
334346

347+
msg = (
348+
"Some of the reference variables contain zeroes. Division by zero "
349+
"does not exist. Replace zeros before using this transformer for division "
350+
"or set `fill_value` to a number."
351+
)
352+
# check that the error message matches
353+
assert str(record.value) == msg
354+
355+
356+
@pytest.mark.parametrize("_fill_value, _func", [
357+
(111.111, ["div"]),
358+
(999, ["div"]),
359+
(111.111, ["truediv"]),
360+
(999, ["truediv"]),
361+
(111.111, ["floordiv"]),
362+
(999, ["floordiv"]),
363+
(111.111, ["mod"]),
364+
(999, ["mod"]),
365+
])
366+
def test_fill_values_when_division_by_zero(
367+
_fill_value, _func, df_vartypes
368+
):
369+
df_zero = df_vartypes.copy()
370+
df_zero.loc[2, "Marks"] = 0
371+
df_zero.loc[1, "Age"] = np.nan
372+
df_zero.loc[3, "Age"] = np.inf
373+
374+
transformer = RelativeFeatures(
375+
variables=["Age"],
376+
reference=["Marks"],
377+
fill_value=_fill_value,
378+
func=_func,
379+
missing_values="ignore",
380+
)
381+
382+
X = transformer.fit_transform(df_zero)
383+
384+
new_var = f"Age_{_func[0]}_Marks"
385+
386+
assert X.loc[2, new_var] == _fill_value
387+
np.testing.assert_equal(X.loc[1, "Age"], np.nan)
388+
np.testing.assert_equal(X.loc[3, "Age"], np.inf)
389+
335390

336391
@pytest.mark.parametrize("_drop", [True, False])
337392
def test_get_feature_names_out(_drop, df_vartypes):

0 commit comments

Comments
 (0)