@@ -68,42 +68,102 @@ def test_find_or_check_numerical_variables(df_vartypes, df_numeric_columns):
6868 assert _find_or_check_numerical_variables (df_numeric_columns , 2 ) == [2 ]
6969
7070
def test_find_or_check_categorical_variables(
    df_vartypes, df_datetime, df_numeric_columns
):
    """Exercise ``_find_or_check_categorical_variables`` on mixed-type frames.

    Scenarios covered, in order:

    * user-supplied variables that are not categorical raise ``TypeError``
      (numeric columns, a ``datetime64`` column, integer-named columns, and
      a list mixing numeric and categorical names);
    * an empty list of variables raises ``ValueError``;
    * ``variables=None`` on a frame without categorical columns raises
      ``ValueError``;
    * ``variables=None`` auto-detects categorical columns and skips the
      datetime-like object columns of ``df_datetime``;
    * explicitly named columns are returned as a list, including object
      columns holding datetime-like values and integer column names;
    * columns cast to ``category`` or ``object`` dtype are handled as the
      assertions below spell out.

    The fixtures are presumably defined in the project's ``conftest.py``;
    their exact contents are not visible from this module (TODO confirm).
    """
    vars_cat = ["Name", "City"]
    vars_mix = ["Age", "Marks", "Name"]

    # errors when vars entered by user are not categorical
    # NOTE: the calls are not wrapped in ``assert``; ``pytest.raises`` alone
    # fails the test if nothing is raised, and an ``assert`` on the return
    # value would only obscure that failure message.
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_vartypes, "Marks")
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_datetime, "datetime_range")
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_datetime, ["datetime_range"])
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_numeric_columns, 3)
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_numeric_columns, [0, 2])
    with pytest.raises(TypeError):
        _find_or_check_categorical_variables(df_vartypes, vars_mix)

    # error when user enters empty list
    with pytest.raises(ValueError):
        _find_or_check_categorical_variables(df_vartypes, variables=[])

    # error when df has no categorical variables
    with pytest.raises(ValueError):
        _find_or_check_categorical_variables(df_vartypes[["Age", "Marks"]], None)
    with pytest.raises(ValueError):
        _find_or_check_categorical_variables(
            df_datetime[["date_obj1", "time_obj"]], None
        )

    # when variables=None
    assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat
    assert _find_or_check_categorical_variables(df_datetime, None) == ["Name"]

    # when vars are specified
    assert _find_or_check_categorical_variables(df_vartypes, "Name") == ["Name"]
    assert _find_or_check_categorical_variables(df_datetime, "date_obj1") == [
        "date_obj1"
    ]
    assert _find_or_check_categorical_variables(df_vartypes, vars_cat) == vars_cat
    assert _find_or_check_categorical_variables(df_datetime, ["Name", "date_obj1"]) == [
        "Name",
        "date_obj1",
    ]

    # vars specified, column name is integer
    assert _find_or_check_categorical_variables(df_numeric_columns, [0, 1]) == [0, 1]
    assert _find_or_check_categorical_variables(df_numeric_columns, 0) == [0]
    assert _find_or_check_categorical_variables(df_numeric_columns, 1) == [1]

    # The remaining sections change column dtypes. Work on local copies so the
    # casts cannot leak into other tests should the fixtures be shared
    # (fixture scope is not visible from here).
    df_vartypes = df_vartypes.copy()
    df_datetime = df_datetime.copy()

    # numeric var cast as category
    df_vartypes["Age"] = df_vartypes["Age"].astype("category")
    assert _find_or_check_categorical_variables(df_vartypes, "Age") == ["Age"]
    assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat + ["Age"]
    assert _find_or_check_categorical_variables(df_vartypes, ["Name", "Age"]) == [
        "Name",
        "Age",
    ]

    # datetime vars cast as category
    # object-like datetime
    df_datetime["date_obj1"] = df_datetime["date_obj1"].astype("category")
    assert _find_or_check_categorical_variables(df_datetime, None) == ["Name"]
    assert _find_or_check_categorical_variables(df_datetime, ["Name", "date_obj1"]) == [
        "Name",
        "date_obj1",
    ]
    # datetime64
    df_datetime["datetime_range"] = df_datetime["datetime_range"].astype("category")
    assert _find_or_check_categorical_variables(df_datetime, None) == ["Name"]
    assert _find_or_check_categorical_variables(
        df_datetime, ["Name", "datetime_range"]
    ) == ["Name", "datetime_range"]

    # numeric var cast as object
    # "Age" is still categorical from the section above, hence its presence in
    # the auto-detected list at the end of this section.
    df_vartypes["Marks"] = df_vartypes["Marks"].astype("O")
    assert _find_or_check_categorical_variables(df_vartypes, "Marks") == ["Marks"]
    assert _find_or_check_categorical_variables(df_vartypes, ["Name", "Marks"]) == [
        "Name",
        "Marks",
    ]
    assert _find_or_check_categorical_variables(df_vartypes, None) == vars_cat + [
        "Age",
        "Marks",
    ]

    # time-aware datetime var
    # Appending offset suffixes to the "time_obj" values yields a frame whose
    # only column is not auto-detected as categorical (ValueError) but is
    # accepted when requested by name.
    tz_time = pd.DataFrame(
        {"time_objTZ": df_datetime["time_obj"].add(["+5", "+11", "-3", "-8"])}
    )
    with pytest.raises(ValueError):
        _find_or_check_categorical_variables(tz_time, None)
    assert _find_or_check_categorical_variables(tz_time, "time_objTZ") == ["time_objTZ"]
166+
107167
108168def test_find_or_check_datetime_variables (df_datetime ):
109169 var_dt = ["datetime_range" ]
@@ -116,7 +176,7 @@ def test_find_or_check_datetime_variables(df_datetime):
116176 {"date_obj1_cat" : df_datetime ["date_obj1" ].astype ("category" )}
117177 )
118178 tz_time = pd .DataFrame (
119- {"time_objTZ" : df_datetime ["time_obj" ].add (['+5' , ' +11' , '-3' , '-8' ])}
179+ {"time_objTZ" : df_datetime ["time_obj" ].add (["+5" , " +11" , "-3" , "-8" ])}
120180 )
121181
122182 # error when df has no datetime variables
0 commit comments