@@ -68,6 +68,12 @@ def test_find_or_check_numerical_variables(df_vartypes, df_numeric_columns):
6868 assert _find_or_check_numerical_variables (df_numeric_columns , 2 ) == [2 ]
6969
7070
def _cast_var_as_type(df, var, new_type):
    """Return a copy of ``df`` with column ``var`` cast to ``new_type``.

    The cast is applied to a copy, so the dataframe passed in (typically a
    shared fixture) is left unmodified.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe containing the column to convert.
    var : hashable
        Label of the column to cast.
    new_type : str or dtype
        Target dtype, forwarded as-is to ``astype`` (e.g. ``"category"``,
        ``"O"``).

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` in which ``var`` holds the cast values.
    """
    df_copy = df.copy()
    df_copy[var] = df[var].astype(new_type)
    return df_copy
75+
76+
7177def test_find_or_check_categorical_variables (
7278 df_vartypes , df_datetime , df_numeric_columns
7379):
@@ -120,15 +126,6 @@ def test_find_or_check_categorical_variables(
120126 assert _find_or_check_categorical_variables (df_numeric_columns , 0 ) == [0 ]
121127 assert _find_or_check_categorical_variables (df_numeric_columns , 1 ) == [1 ]
122128
123- # numeric var cast as category
124- df_vartypes ["Age" ] = df_vartypes ["Age" ].astype ("category" )
125- assert _find_or_check_categorical_variables (df_vartypes , "Age" ) == ["Age" ]
126- assert _find_or_check_categorical_variables (df_vartypes , None ) == vars_cat + ["Age" ]
127- assert _find_or_check_categorical_variables (df_vartypes , ["Name" , "Age" ]) == [
128- "Name" ,
129- "Age" ,
130- ]
131-
132129 # datetime vars cast as category
133130 # object-like datetime
134131 df_datetime ["date_obj1" ] = df_datetime ["date_obj1" ].astype ("category" )
@@ -144,18 +141,6 @@ def test_find_or_check_categorical_variables(
144141 df_datetime , ["Name" , "datetime_range" ]
145142 ) == ["Name" , "datetime_range" ]
146143
147- # numeric var cast as object
148- df_vartypes ["Marks" ] = df_vartypes ["Marks" ].astype ("O" )
149- assert _find_or_check_categorical_variables (df_vartypes , "Marks" ) == ["Marks" ]
150- assert _find_or_check_categorical_variables (df_vartypes , ["Name" , "Marks" ]) == [
151- "Name" ,
152- "Marks" ,
153- ]
154- assert _find_or_check_categorical_variables (df_vartypes , None ) == vars_cat + [
155- "Age" ,
156- "Marks" ,
157- ]
158-
159144 # time-aware datetime var
160145 tz_time = pd .DataFrame (
161146 {"time_objTZ" : df_datetime ["time_obj" ].add (["+5" , "+11" , "-3" , "-8" ])}
@@ -165,16 +150,33 @@ def test_find_or_check_categorical_variables(
165150 assert _find_or_check_categorical_variables (tz_time , "time_objTZ" ) == ["time_objTZ" ]
166151
167152
@pytest.mark.parametrize(
    "_num_var, _cat_type",
    [("Age", "category"), ("Age", "O"), ("Marks", "category"), ("Marks", "O")],
)
def test_find_or_check_categorical_variables_when_numeric_is_cast_as_category_or_object(
    df_vartypes, _num_var, _cat_type
):
    """Check that a numeric column cast to category/object counts as categorical.

    Each parametrized case casts a single numeric column (``_num_var``) to
    ``_cat_type`` and verifies the result of
    ``_find_or_check_categorical_variables`` for a single name, for automatic
    selection (``None``) and for an explicit list.
    """
    # Cast on a copy so each parametrized case starts from the original fixture.
    df_vartypes = _cast_var_as_type(df_vartypes, _num_var, _cat_type)

    # A single variable name is returned wrapped in a list.
    assert _find_or_check_categorical_variables(df_vartypes, _num_var) == [_num_var]

    # Automatic selection: the cast column is expected after "Name" and "City".
    assert _find_or_check_categorical_variables(df_vartypes, None) == [
        "Name",
        "City",
        _num_var,
    ]

    # An explicit list containing the cast column is returned unchanged.
    assert _find_or_check_categorical_variables(df_vartypes, ["Name", _num_var]) == [
        "Name",
        _num_var,
    ]
171+
172+
168173def test_find_or_check_datetime_variables (df_datetime ):
169174 var_dt = ["datetime_range" ]
170175 var_dt_str = "datetime_range"
171176 vars_nondt = ["Age" , "Name" ]
172177 vars_convertible_to_dt = ["datetime_range" , "date_obj1" , "date_obj2" , "time_obj" ]
173178 var_convertible_to_dt = "date_obj1"
174179 vars_mix = ["datetime_range" , "Age" ]
175- cat_date = pd .DataFrame (
176- {"date_obj1_cat" : df_datetime ["date_obj1" ].astype ("category" )}
177- )
178180 tz_time = pd .DataFrame (
179181 {"time_objTZ" : df_datetime ["time_obj" ].add (["+5" , "+11" , "-3" , "-8" ])}
180182 )
@@ -229,19 +231,40 @@ def test_find_or_check_datetime_variables(df_datetime):
229231 )
230232 == vars_convertible_to_dt + ["time_objTZ" ]
231233 )
232- # vars cast as categorical
233- assert _find_or_check_datetime_variables (cat_date , variables = "date_obj1_cat" ) == [
234- "date_obj1_cat"
234+
235+ # datetime var cast as categorical
236+ df_datetime ["date_obj1" ] = df_datetime ["date_obj1" ].astype ("category" )
237+ assert _find_or_check_datetime_variables (df_datetime , variables = "date_obj1" ) == [
238+ "date_obj1"
235239 ]
236240 assert (
237- _find_or_check_datetime_variables (
238- df_datetime [vars_convertible_to_dt ].join (cat_date ),
239- variables = vars_convertible_to_dt + ["date_obj1_cat" ],
240- )
241- == vars_convertible_to_dt + ["date_obj1_cat" ]
241+ _find_or_check_datetime_variables (df_datetime , variables = vars_convertible_to_dt )
242+ == vars_convertible_to_dt
242243 )
243244
244245
@pytest.mark.parametrize("_num_var, _cat_type", [("Age", "category"), ("Age", "O")])
def test_find_or_check_datetime_variables_when_numeric_is_cast_as_category_or_object(
    df_datetime, _num_var, _cat_type
):
    """Check that a numeric column cast to category/object is not seen as datetime.

    After casting ``_num_var`` to ``_cat_type``, requesting that column
    explicitly must raise ``TypeError``; automatic selection must raise
    ``ValueError`` when it is the only column, and must leave it out otherwise.
    """
    # Cast on a copy so each parametrized case starts from the original fixture.
    df_datetime = _cast_var_as_type(df_datetime, _num_var, _cat_type)

    # Explicitly requesting the cast column, as a name or as a list, is rejected.
    # The calls are expected to raise, so their results are not asserted on.
    with pytest.raises(TypeError):
        _find_or_check_datetime_variables(df_datetime, variables=_num_var)
    with pytest.raises(TypeError):
        _find_or_check_datetime_variables(df_datetime, variables=[_num_var])

    # Automatic selection on a frame holding only the cast column finds nothing.
    with pytest.raises(ValueError) as errinfo:
        _find_or_check_datetime_variables(df_datetime[[_num_var]], variables=None)
    assert str(errinfo.value) == "No datetime variables found in this dataframe."

    # Automatic selection on the full frame returns these columns and omits
    # the cast numeric one.
    assert _find_or_check_datetime_variables(df_datetime, variables=None) == [
        "datetime_range",
        "date_obj1",
        "date_obj2",
        "time_obj",
    ]
266+
267+
245268def test_find_all_variables (df_vartypes ):
246269 all_vars = ["Name" , "City" , "Age" , "Marks" , "dob" ]
247270 user_vars = ["Name" , "City" ]
0 commit comments