Skip to content

Commit 224c734

Browse files
authored
Updates the datetime variable checker function to handle the ambiguous categorical/object cases raised in #336 (#347)
* The datetime find/check function now correctly handles integer/numerical variables cast as category/object
* Refactors the repeated type-check logic
* Wraps the new tests with pytest.mark.parametrize
* Redesigns the latest tests
1 parent 1c051bf commit 224c734

File tree

2 files changed

+89
-49
lines changed

2 files changed

+89
-49
lines changed

feature_engine/variable_manipulation.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,31 @@ def _find_or_check_numerical_variables(
9393
return variables
9494

9595

96+
def _is_convertible_to_num(column: pd.Series) -> bool:
97+
return is_numeric(pd.to_numeric(column, errors="ignore"))
98+
99+
100+
def _is_convertible_to_dt(column: pd.Series) -> bool:
101+
return is_datetime(pd.to_datetime(column, errors="ignore", utc=True))
102+
103+
104+
def _is_categories_num(column: pd.Series) -> bool:
105+
return is_numeric(column.dtype.categories)
106+
107+
96108
def _is_categorical_and_is_not_datetime(column: pd.Series) -> bool:
97109

98110
# check for datetime only if object cannot be cast as numeric because
99111
# if it could pd.to_datetime would convert it to datetime regardless
100112
if is_object(column):
101-
is_categorical_and_is_not_datetime = is_numeric(
102-
pd.to_numeric(column, errors="ignore")
103-
) or not is_datetime(pd.to_datetime(column, errors="ignore", utc=True))
113+
is_cat = _is_convertible_to_num(column) or not _is_convertible_to_dt(column)
104114

105115
# check for datetime only if the type of the categories is not numeric
106116
# because pd.to_datetime throws an error when it is an integer
107117
elif is_categorical(column):
108-
is_categorical_and_is_not_datetime = is_numeric(
109-
column.dtype.categories
110-
) or not is_datetime(pd.to_datetime(column, errors="ignore", utc=True))
118+
is_cat = _is_categories_num(column) or not _is_convertible_to_dt(column)
111119

112-
return is_categorical_and_is_not_datetime
120+
return is_cat
113121

114122

115123
def _find_or_check_categorical_variables(
@@ -170,6 +178,21 @@ def _find_or_check_categorical_variables(
170178
return variables
171179

172180

181+
def _is_categorical_and_is_datetime(column: pd.Series) -> bool:
182+
183+
# check for datetime only if object cannot be cast as numeric because
184+
# if it could pd.to_datetime would convert it to datetime regardless
185+
if is_object(column):
186+
is_dt = not _is_convertible_to_num(column) and _is_convertible_to_dt(column)
187+
188+
# check for datetime only if the type of the categories is not numeric
189+
# because pd.to_datetime throws an error when it is an integer
190+
elif is_categorical(column):
191+
is_dt = not _is_categories_num(column) and _is_convertible_to_dt(column)
192+
193+
return is_dt
194+
195+
173196
def _find_or_check_datetime_variables(
174197
X: pd.DataFrame, variables: Variables = None
175198
) -> List[Union[str, int]]:
@@ -191,8 +214,7 @@ def _find_or_check_datetime_variables(
191214
variables = [
192215
column
193216
for column in X.select_dtypes(exclude="number").columns
194-
if is_datetime(X[column])
195-
or is_datetime(pd.to_datetime(X[column], errors="ignore", utc=True))
217+
if is_datetime(X[column]) or _is_categorical_and_is_datetime(X[column])
196218
]
197219

198220
if len(variables) == 0:
@@ -202,7 +224,7 @@ def _find_or_check_datetime_variables(
202224

203225
if is_datetime(X[variables]) or (
204226
not is_numeric(X[variables])
205-
and is_datetime(pd.to_datetime(X[variables], errors="ignore", utc=True))
227+
and _is_categorical_and_is_datetime(X[variables])
206228
):
207229
variables = [variables]
208230
else:
@@ -216,14 +238,9 @@ def _find_or_check_datetime_variables(
216238
else:
217239
vars_non_dt = [
218240
column
219-
for column in variables
241+
for column in X[variables].select_dtypes(exclude="datetime")
220242
if is_numeric(X[column])
221-
or (
222-
not is_datetime(X[column])
223-
and not is_datetime(
224-
pd.to_datetime(X[column], errors="ignore", utc=True)
225-
)
226-
)
243+
or not _is_categorical_and_is_datetime(X[column])
227244
]
228245

229246
if len(vars_non_dt) > 0:

tests/test_variable_manipulation.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ def test_find_or_check_numerical_variables(df_vartypes, df_numeric_columns):
6868
assert _find_or_check_numerical_variables(df_numeric_columns, 2) == [2]
6969

7070

71+
def _cast_var_as_type(df, var, new_type):
72+
df_copy = df.copy()
73+
df_copy[var] = df[var].astype(new_type)
74+
return df_copy
75+
76+
7177
def test_find_or_check_categorical_variables(
7278
df_vartypes, df_datetime, df_numeric_columns
7379
):
@@ -120,15 +126,6 @@ def test_find_or_check_categorical_variables(
120126
assert _find_or_check_categorical_variables(df_numeric_columns, 0) == [0]
121127
assert _find_or_check_categorical_variables(df_numeric_columns, 1) == [1]
122128

123-
# numeric var cast as category
124-
df_vartypes["Age"] = df_vartypes["Age"].astype("category")
125-
assert _find_or_check_categorical_variables(df_vartypes, "Age") == ["Age"]
126-
assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat + ["Age"]
127-
assert _find_or_check_categorical_variables(df_vartypes, ["Name", "Age"]) == [
128-
"Name",
129-
"Age",
130-
]
131-
132129
# datetime vars cast as category
133130
# object-like datetime
134131
df_datetime["date_obj1"] = df_datetime["date_obj1"].astype("category")
@@ -144,18 +141,6 @@ def test_find_or_check_categorical_variables(
144141
df_datetime, ["Name", "datetime_range"]
145142
) == ["Name", "datetime_range"]
146143

147-
# numeric var cast as object
148-
df_vartypes["Marks"] = df_vartypes["Marks"].astype("O")
149-
assert _find_or_check_categorical_variables(df_vartypes, "Marks") == ["Marks"]
150-
assert _find_or_check_categorical_variables(df_vartypes, ["Name", "Marks"]) == [
151-
"Name",
152-
"Marks",
153-
]
154-
assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat + [
155-
"Age",
156-
"Marks",
157-
]
158-
159144
# time-aware datetime var
160145
tz_time = pd.DataFrame(
161146
{"time_objTZ": df_datetime["time_obj"].add(["+5", "+11", "-3", "-8"])}
@@ -165,16 +150,33 @@ def test_find_or_check_categorical_variables(
165150
assert _find_or_check_categorical_variables(tz_time, "time_objTZ") == ["time_objTZ"]
166151

167152

153+
@pytest.mark.parametrize(
    "_num_var, _cat_type",
    [("Age", "category"), ("Age", "O"), ("Marks", "category"), ("Marks", "O")],
)
def test_find_or_check_categorical_variables_when_numeric_is_cast_as_category_or_object(
    df_vartypes, _num_var, _cat_type
):
    """A numeric variable cast as category/object is treated as categorical."""
    df_cast = _cast_var_as_type(df_vartypes, _num_var, _cat_type)

    # variable passed on its own
    found_single = _find_or_check_categorical_variables(df_cast, _num_var)
    assert found_single == [_num_var]

    # variables=None: the cast variable is found alongside "Name" and "City"
    found_auto = _find_or_check_categorical_variables(df_cast, None)
    assert found_auto == ["Name", "City", _num_var]

    # variable passed in a list together with another categorical variable
    found_listed = _find_or_check_categorical_variables(df_cast, ["Name", _num_var])
    assert found_listed == ["Name", _num_var]
171+
172+
168173
def test_find_or_check_datetime_variables(df_datetime):
169174
var_dt = ["datetime_range"]
170175
var_dt_str = "datetime_range"
171176
vars_nondt = ["Age", "Name"]
172177
vars_convertible_to_dt = ["datetime_range", "date_obj1", "date_obj2", "time_obj"]
173178
var_convertible_to_dt = "date_obj1"
174179
vars_mix = ["datetime_range", "Age"]
175-
cat_date = pd.DataFrame(
176-
{"date_obj1_cat": df_datetime["date_obj1"].astype("category")}
177-
)
178180
tz_time = pd.DataFrame(
179181
{"time_objTZ": df_datetime["time_obj"].add(["+5", "+11", "-3", "-8"])}
180182
)
@@ -229,19 +231,40 @@ def test_find_or_check_datetime_variables(df_datetime):
229231
)
230232
== vars_convertible_to_dt + ["time_objTZ"]
231233
)
232-
# vars cast as categorical
233-
assert _find_or_check_datetime_variables(cat_date, variables="date_obj1_cat") == [
234-
"date_obj1_cat"
234+
235+
# datetime var cast as categorical
236+
df_datetime["date_obj1"] = df_datetime["date_obj1"].astype("category")
237+
assert _find_or_check_datetime_variables(df_datetime, variables="date_obj1") == [
238+
"date_obj1"
235239
]
236240
assert (
237-
_find_or_check_datetime_variables(
238-
df_datetime[vars_convertible_to_dt].join(cat_date),
239-
variables=vars_convertible_to_dt + ["date_obj1_cat"],
240-
)
241-
== vars_convertible_to_dt + ["date_obj1_cat"]
241+
_find_or_check_datetime_variables(df_datetime, variables=vars_convertible_to_dt)
242+
== vars_convertible_to_dt
242243
)
243244

244245

246+
@pytest.mark.parametrize("_num_var, _cat_type", [("Age", "category"), ("Age", "O")])
def test_find_or_check_datetime_variables_when_numeric_is_cast_as_category_or_object(
    df_datetime, _num_var, _cat_type
):
    """A numeric variable cast as category/object is not treated as datetime."""
    df_datetime = _cast_var_as_type(df_datetime, _num_var, _cat_type)

    # The calls below are expected to raise, so their return value is never
    # produced; wrapping them in ``assert`` would be dead code.
    with pytest.raises(TypeError):
        _find_or_check_datetime_variables(df_datetime, variables=_num_var)
    with pytest.raises(TypeError):
        _find_or_check_datetime_variables(df_datetime, variables=[_num_var])

    # a dataframe holding only the cast variable has no datetime variables
    with pytest.raises(ValueError) as errinfo:
        _find_or_check_datetime_variables(df_datetime[[_num_var]], variables=None)
    assert str(errinfo.value) == "No datetime variables found in this dataframe."

    # variables=None: the cast variable is not among those found
    assert _find_or_check_datetime_variables(df_datetime, variables=None) == [
        "datetime_range",
        "date_obj1",
        "date_obj2",
        "time_obj",
    ]
266+
267+
245268
def test_find_all_variables(df_vartypes):
246269
all_vars = ["Name", "City", "Age", "Marks", "dob"]
247270
user_vars = ["Name", "City"]

0 commit comments

Comments
 (0)