Skip to content

Commit 6adef90

Browse files
authored
split function to check na into strict and optional (#608)
* split check_na in 2
* replace check for optional na
* update match categories
1 parent 07de060 commit 6adef90

File tree

6 files changed: 46 additions & 27 deletions

feature_engine/dataframe_checks.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _check_X_matches_training_df(X: pd.DataFrame, reference: int) -> None:
243243

244244

245245
def _check_contains_na(
246-
X: pd.DataFrame, variables: List[Union[str, int]], switch_param: bool = False
246+
X: pd.DataFrame, variables: List[Union[str, int]],
247247
) -> None:
248248
"""
249249
Checks if DataFrame contains null values in the selected columns.
@@ -255,9 +255,31 @@ def _check_contains_na(
255255
variables : List
256256
The selected group of variables in which null values will be examined.
257257
258-
switch_param: bool
259-
Whether the transformer has the parameter missing_values in the init to modify
260-
its behaviour towards nan.
258+
Raises
259+
------
260+
ValueError
261+
If the variable(s) contain null values.
262+
"""
263+
264+
if X[variables].isnull().any().any():
265+
raise ValueError(
266+
"Some of the variables in the dataset contain NaN. Check and "
267+
"remove those before using this transformer."
268+
)
269+
270+
271+
def _check_optional_contains_na(
272+
X: pd.DataFrame, variables: List[Union[str, int]]
273+
) -> None:
274+
"""
275+
Checks if DataFrame contains null values in the selected columns.
276+
277+
Parameters
278+
----------
279+
X : Pandas DataFrame
280+
281+
variables : List
282+
The selected group of variables in which null values will be examined.
261283
262284
Raises
263285
------
@@ -266,17 +288,11 @@ def _check_contains_na(
266288
"""
267289

268290
if X[variables].isnull().any().any():
269-
if switch_param is False:
270-
raise ValueError(
271-
"Some of the variables in the dataset contain NaN. Check and "
272-
"remove those before using this transformer."
273-
)
274-
else:
275-
raise ValueError(
276-
"Some of the variables in the dataset contain NaN. Check and "
277-
"remove those before using this transformer or set the parameter "
278-
"`missing_values='ignore'` when initialising this transformer."
279-
)
291+
raise ValueError(
292+
"Some of the variables in the dataset contain NaN. Check and "
293+
"remove those before using this transformer or set the parameter "
294+
"`missing_values='ignore'` when initialising this transformer."
295+
)
280296

281297

282298
def _check_contains_inf(X: pd.DataFrame, variables: List[Union[str, int]]) -> None:

feature_engine/encoding/base_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
_find_or_check_categorical_variables,
2121
)
2222
from feature_engine.dataframe_checks import (
23-
_check_contains_na,
23+
_check_optional_contains_na,
2424
_check_X_matches_training_df,
2525
check_X,
2626
)
@@ -110,7 +110,7 @@ class CategoricalMethodsMixin(BaseEstimator, TransformerMixin, GetFeatureNamesOu
110110

111111
def _check_na(self, X: pd.DataFrame, variables):
112112
if self.missing_values == "raise":
113-
_check_contains_na(X, variables, switch_param=True)
113+
_check_optional_contains_na(X, variables)
114114

115115
def _check_or_select_variables(self, X: pd.DataFrame):
116116
"""
@@ -207,7 +207,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
207207

208208
# check if dataset contains na
209209
if self.missing_values == "raise":
210-
_check_contains_na(X, self.variables_, switch_param=True)
210+
_check_optional_contains_na(X, self.variables_)
211211

212212
X = self._encode(X)
213213

feature_engine/encoding/rare_label.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from feature_engine._docstrings.init_parameters.encoders import _ignore_format_docstring
2020
from feature_engine._docstrings.methods import _fit_transform_docstring
2121
from feature_engine._docstrings.substitute import Substitution
22-
from feature_engine.dataframe_checks import _check_contains_na, check_X
22+
from feature_engine.dataframe_checks import _check_optional_contains_na, check_X
2323
from feature_engine.encoding.base_encoder import (
2424
CategoricalInitMixinNA,
2525
CategoricalMethodsMixin,
@@ -244,7 +244,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
244244

245245
# check if dataset contains na
246246
if self.missing_values == "raise":
247-
_check_contains_na(X, self.variables_, switch_param=True)
247+
_check_optional_contains_na(X, self.variables_)
248248

249249
for feature in self.variables_:
250250
X[feature] = np.where(

feature_engine/encoding/similarity_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from feature_engine._docstrings.init_parameters.encoders import _ignore_format_docstring
1717
from feature_engine._docstrings.methods import _fit_transform_docstring
1818
from feature_engine._docstrings.substitute import Substitution
19-
from feature_engine.dataframe_checks import _check_contains_na, check_X
19+
from feature_engine.dataframe_checks import _check_optional_contains_na, check_X
2020
from feature_engine.encoding.base_encoder import (
2121
CategoricalInitMixin,
2222
CategoricalMethodsMixin,
@@ -241,7 +241,7 @@ def fit(self, X: pd.DataFrame, y: Optional[pd.Series] = None):
241241

242242
# if data contains nan, fail before running any logic
243243
if self.missing_values == "raise":
244-
_check_contains_na(X, variables_, switch_param=True)
244+
_check_optional_contains_na(X, variables_)
245245

246246
self.encoder_dict_ = {}
247247

@@ -311,7 +311,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
311311
check_is_fitted(self)
312312
X = self._check_transform_input_and_state(X)
313313
if self.missing_values == "raise":
314-
_check_contains_na(X, self.variables_, switch_param=True)
314+
_check_optional_contains_na(X, self.variables_)
315315

316316
new_values = []
317317
for var in self.variables_:

feature_engine/preprocessing/match_categories.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from feature_engine._docstrings.init_parameters.encoders import _ignore_format_docstring
1616
from feature_engine._docstrings.substitute import Substitution
17-
from feature_engine.dataframe_checks import _check_contains_na, check_X
17+
from feature_engine.dataframe_checks import _check_optional_contains_na, check_X
1818
from feature_engine.encoding.base_encoder import (
1919
CategoricalInitMixinNA,
2020
CategoricalMethodsMixin,
@@ -116,7 +116,7 @@ def fit(self, X: pd.DataFrame, y: Optional[pd.Series] = None):
116116
variables_ = self._check_or_select_variables(X)
117117

118118
if self.missing_values == "raise":
119-
_check_contains_na(X, variables_, switch_param=True)
119+
_check_optional_contains_na(X, variables_)
120120

121121
self.category_dict_ = dict()
122122
for var in variables_:
@@ -143,7 +143,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
143143
X = self._check_transform_input_and_state(X)
144144

145145
if self.missing_values == "raise":
146-
_check_contains_na(X, self.variables_, switch_param=True)
146+
_check_optional_contains_na(X, self.variables_)
147147

148148
for feature, levels in self.category_dict_.items():
149149
X[feature] = pd.Categorical(X[feature], levels)

tests/test_dataframe_checks.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from feature_engine.dataframe_checks import (
88
_check_contains_inf,
99
_check_contains_na,
10+
_check_optional_contains_na,
1011
_check_X_matches_training_df,
1112
check_X,
1213
check_X_y,
@@ -152,14 +153,16 @@ def test_contains_na(df_na):
152153
assert _check_contains_na(df_na, ["Name", "City"])
153154
assert str(record.value) == msg
154155

156+
157+
def test_optional_contains_na(df_na):
155158
msg = (
156159
"Some of the variables in the dataset contain NaN. Check and "
157160
"remove those before using this transformer or set the parameter "
158161
"`missing_values='ignore'` when initialising this transformer."
159162
)
160163

161164
with pytest.raises(ValueError) as record:
162-
assert _check_contains_na(df_na, ["Name", "City"], switch_param=True)
165+
assert _check_optional_contains_na(df_na, ["Name", "City"])
163166
assert str(record.value) == msg
164167

165168

0 commit comments

Comments
 (0)