Skip to content

Commit 066a0e6

Browse files
dodoarg and solegalli authored
Introduces check for var type when user passes var name outside a list closes #337 (#339)
* fixed issue
* added couple more tests
* fixes issue and rejects empty lists
* changes wording and elif statement
* fixed typo

Co-authored-by: Soledad Galli <[email protected]>
1 parent cceeed2 commit 066a0e6

File tree

3 files changed

+76
-31
lines changed

3 files changed

+76
-31
lines changed

feature_engine/selection/target_mean_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
201201
# find categorical and numerical variables
202202
self.variables_categorical_ = list(X.select_dtypes(include="O").columns)
203203
self.variables_numerical_ = list(
204-
X.select_dtypes(include=["float", "integer"]).columns
204+
X.select_dtypes(include="number").columns
205205
)
206206

207207
# obtain cross-validation indices

feature_engine/variable_manipulation.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
import pandas as pd
66

7+
from pandas.api.types import is_numeric_dtype as is_numeric
8+
from pandas.api.types import is_categorical_dtype as is_categorical
9+
from pandas.api.types import is_object_dtype as is_object
10+
711
Variables = Union[None, int, str, List[Union[str, int]]]
812

913

@@ -44,40 +48,47 @@ def _find_or_check_numerical_variables(
4448
4549
Parameters
4650
----------
47-
X : Pandas DataFrame
51+
X : Pandas DataFrame.
4852
variables : variable or list of variables. Defaults to None.
4953
5054
Raises
5155
------
5256
ValueError
53-
If there are no numerical variables in the df or the df is empty
57+
If there are no numerical variables in the df or the df is empty.
5458
TypeError
55-
If any of the user provided variables are not numerical
59+
If any of the user provided variables are not numerical.
5660
5761
Returns
5862
-------
59-
variables: List of numerical variables
63+
variables: List of numerical variables.
6064
"""
6165

62-
if isinstance(variables, (str, int)):
63-
variables = [variables]
64-
65-
elif not variables:
66+
if variables is None:
6667
# find numerical variables in dataset
6768
variables = list(X.select_dtypes(include="number").columns)
6869
if len(variables) == 0:
6970
raise ValueError(
70-
"No numerical variables in this dataframe. Please check variable"
71-
"format with pandas dtypes"
71+
"No numerical variables found in this dataframe. Please check "
72+
"variable format with pandas dtypes."
7273
)
7374

75+
elif isinstance(variables, (str, int)):
76+
if is_numeric(X[variables]):
77+
variables = [variables]
78+
else:
79+
raise TypeError("The variable entered is not numeric.")
80+
7481
else:
82+
if len(variables) == 0:
83+
raise ValueError("The list of variables is empty.")
84+
7585
# check that user entered variables are of type numerical
76-
if any(X[variables].select_dtypes(exclude="number").columns):
77-
raise TypeError(
78-
"Some of the variables are not numerical. Please cast them as "
79-
"numerical before using this transformer"
80-
)
86+
else:
87+
if len(X[variables].select_dtypes(exclude="number").columns) > 0:
88+
raise TypeError(
89+
"Some of the variables are not numerical. Please cast them as "
90+
"numerical before using this transformer."
91+
)
8192

8293
return variables
8394

@@ -91,38 +102,47 @@ def _find_or_check_categorical_variables(
91102
92103
Parameters
93104
----------
94-
X : pandas DataFrame
105+
X : pandas DataFrame.
95106
variables : variable or list of variables. Defaults to None.
96107
97108
Raises
98109
------
99110
ValueError
100-
If there are no categorical variables in df or df is empty
111+
If there are no categorical variables in df or df is empty.
101112
TypeError
102-
If any of the user provided variables are not categorical
113+
If any of the user provided variables are not categorical.
103114
104115
Returns
105116
-------
106-
variables : List of categorical variables
117+
variables : List of categorical variables.
107118
"""
108119

109-
if isinstance(variables, (str, int)):
110-
variables = [variables]
111-
112-
elif not variables:
120+
if variables is None:
121+
# find categorical variables in dataset
113122
variables = list(X.select_dtypes(include=["O", "category"]).columns)
114123
if len(variables) == 0:
115124
raise ValueError(
116-
"No categorical variables in this dataframe. Please check the "
117-
"variables format with pandas dtypes"
125+
"No categorical variables found in this dataframe. Please check "
126+
"variable format with pandas dtypes."
118127
)
119128

129+
elif isinstance(variables, (str, int)):
130+
if is_categorical(X[variables]) or is_object(X[variables]):
131+
variables = [variables]
132+
else:
133+
raise TypeError("The variable entered is not categorical.")
134+
120135
else:
121-
if any(X[variables].select_dtypes(exclude=["O", "category"]).columns):
122-
raise TypeError(
123-
"Some of the variables are not categorical. Please cast them as object "
124-
"or category before calling this transformer"
125-
)
136+
if len(variables) == 0:
137+
raise ValueError("The list of variables is empty.")
138+
139+
# check that user entered variables are of type categorical
140+
else:
141+
if len(X[variables].select_dtypes(exclude=["O", "category"]).columns) > 0:
142+
raise TypeError(
143+
"Some of the variables are not categorical. Please cast them as "
144+
"categorical or object before using this transformer."
145+
)
126146

127147
return variables
128148

tests/test_variable_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,21 @@ def test_find_or_check_numerical_variables(df_vartypes, df_numeric_columns):
4444
assert _find_or_check_numerical_variables(df_vartypes, var_num) == ["Age"]
4545
assert _find_or_check_numerical_variables(df_vartypes, vars_none) == vars_num
4646

47+
with pytest.raises(TypeError):
48+
assert _find_or_check_numerical_variables(df_vartypes, "City")
49+
50+
with pytest.raises(TypeError):
51+
assert _find_or_check_numerical_variables(df_numeric_columns, 0)
52+
53+
with pytest.raises(TypeError):
54+
assert _find_or_check_numerical_variables(df_numeric_columns, [1, 3])
55+
4756
with pytest.raises(TypeError):
4857
assert _find_or_check_numerical_variables(df_vartypes, vars_mix)
4958

59+
with pytest.raises(ValueError):
60+
assert _find_or_check_numerical_variables(df_vartypes, variables=[])
61+
5062
with pytest.raises(ValueError):
5163
assert _find_or_check_numerical_variables(df_vartypes[["Name", "City"]], None)
5264

@@ -61,13 +73,26 @@ def test_find_or_check_categorical_variables(df_vartypes, df_numeric_columns):
6173
assert _find_or_check_categorical_variables(df_vartypes, vars_cat) == vars_cat
6274
assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat
6375

76+
with pytest.raises(TypeError):
77+
assert _find_or_check_categorical_variables(df_vartypes, "Marks")
78+
79+
with pytest.raises(TypeError):
80+
assert _find_or_check_categorical_variables(df_numeric_columns, 3)
81+
82+
with pytest.raises(TypeError):
83+
assert _find_or_check_categorical_variables(df_numeric_columns, [0, 2])
84+
6485
with pytest.raises(TypeError):
6586
assert _find_or_check_categorical_variables(df_vartypes, vars_mix)
6687

88+
with pytest.raises(ValueError):
89+
assert _find_or_check_categorical_variables(df_vartypes, variables=[])
90+
6791
with pytest.raises(ValueError):
6892
assert _find_or_check_categorical_variables(df_vartypes[["Age", "Marks"]], None)
6993

7094
assert _find_or_check_categorical_variables(df_numeric_columns, [0, 1]) == [0, 1]
95+
assert _find_or_check_categorical_variables(df_numeric_columns, 0) == [0]
7196
assert _find_or_check_categorical_variables(df_numeric_columns, 1) == [1]
7297

7398
df_vartypes["Age"] = df_vartypes["Age"].astype("category")

0 commit comments

Comments
 (0)