1
+ """Data check that checks if any of the target labels are imbalanced, or if the number of values for each target is below 2 times the number of CV folds.
2
+
3
+ Use for classification problems.
4
+ """
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from checkmates .data_checks import (
9
+ DataCheck ,
10
+ DataCheckError ,
11
+ DataCheckMessageCode ,
12
+ DataCheckWarning ,
13
+ )
14
+ from checkmates .utils import infer_feature_types
15
+
16
+
17
class ClassImbalanceDataCheck(DataCheck):
    """Check if any of the target labels are imbalanced, or if the number of values for each target is below 2 times the number of CV folds. Use for classification problems.

    Args:
        threshold (float): The minimum threshold allowed for class imbalance before a warning is raised.
            This threshold is calculated by comparing the number of samples in each class to the sum of samples in that class and the majority class.
            For example, a multiclass case with [900, 900, 100] samples per classes 0, 1, and 2, respectively,
            would have a 0.10 threshold for class 2 (100 / (900 + 100)). Must be in the range (0, 0.5]. Defaults to 0.10.
        min_samples (int): The minimum number of samples per accepted class. If the minority class is both below the threshold and min_samples,
            then we consider this severely imbalanced. Must be greater than 0. Defaults to 100.
        num_cv_folds (int): The number of cross-validation folds. Must be non-negative. Choose 0 to ignore this warning. Defaults to 3.
        test_size (None, float, int): Fraction in the range (0, 1] by which the per-class counts are scaled (and then floored)
            before the checks run, so that imbalance is evaluated prior to splitting the data into training and
            validation/test sets. Defaults to None, which is treated as 1 (no scaling).

    Raises:
        ValueError: If threshold is not within the range (0, 0.5].
        ValueError: If min_samples is not greater than 0.
        ValueError: If num_cv_folds is negative.
        ValueError: If test_size is provided and is not a number in the range (0, 1].
    """

    def __init__(self, threshold=0.1, min_samples=100, num_cv_folds=3, test_size=None):
        if threshold <= 0 or threshold > 0.5:
            raise ValueError(
                "Provided threshold {} is not within the range (0, 0.5]".format(
                    threshold,
                ),
            )
        self.threshold = threshold

        if min_samples <= 0:
            raise ValueError(
                "Provided value min_samples {} is not greater than 0".format(
                    min_samples,
                ),
            )
        self.min_samples = min_samples

        if num_cv_folds < 0:
            raise ValueError(
                "Provided number of CV folds {} is less than 0".format(num_cv_folds),
            )
        # Stored pre-doubled: every class needs at least 2 * num_cv_folds samples.
        self.cv_folds = num_cv_folds * 2

        if test_size is None:
            # No scaling requested: per-class counts are used as-is.
            self.test_size = 1
        else:
            is_valid_fraction = (
                isinstance(test_size, (int, float)) and 0 < test_size <= 1
            )
            if not is_valid_fraction:
                raise ValueError(
                    "Parameter test_size must be a number between 0 and less than or equal to 1",
                )
            self.test_size = test_size

    def validate(self, X, y):
        """Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.

        Ignores NaN values in target labels if they appear.

        Args:
            X (pd.DataFrame, np.ndarray): Features. Ignored.
            y (pd.Series, np.ndarray): Target labels to check for imbalanced data.

        Returns:
            list[dict]: List of message dictionaries: DataCheckWarnings if imbalance in classes is less than the threshold,
                and DataCheckErrors if the number of values for each target is below 2 * num_cv_folds.
                The list is empty when no issue is found or when the target has no countable labels.

        Examples:
            >>> import pandas as pd
            ...
            >>> X = pd.DataFrame()
            >>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

            In this binary example, the target class 0 is present in fewer than 10% (threshold=0.10) of instances, and fewer than 2 * the number
            of cross folds (2 * 3 = 6). Therefore, both a warning and an error are returned as part of the Class Imbalance Data Check.
            In addition, if a target is present with fewer than `min_samples` occurrences (default is 100) and is under the threshold,
            a severe class imbalance warning will be raised.

            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.10)
            >>> assert class_imb_dc.validate(X, y) == [
            ...     {
            ...         "message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "error",
            ...         "code": "CLASS_IMBALANCE_BELOW_FOLDS",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels fall below 10% of the target: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_SEVERE",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     }
            ... ]


            In this multiclass example, the target class 0 is present in fewer than 30% of observations, however with 1 cv fold, the minimum
            number of instances required is 2 * 1 = 2. Therefore a warning, but not an error, is raised.

            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, min_samples=5, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == [
            ...     {
            ...         "message": "The following labels fall below 30% of the target: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels in the target have severe class imbalance because they fall under 30% of the target and have less than 5 samples: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_SEVERE",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     }
            ... ]
            ...
            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == []
        """
        messages = []

        # Count the labels as supplied before type inference, so that messages
        # can report the caller's original label values.
        original_vc = pd.Series(y).value_counts(sort=True)
        # NOTE(review): infer_feature_types is defined elsewhere; presumably it
        # returns a Woodwork-initialized Series (the `.ww` accessor is used below).
        y = infer_feature_types(y)
        new_vc = y.value_counts(sort=True)

        if str(y.ww.logical_type) in ["Boolean", "BooleanNullable"]:
            # Inference may have replaced the original labels with booleans, so
            # pair each inferred label with the original label at the same
            # position. This relies on both value counts listing labels in the
            # same (frequency) order.
            inferred_to_original = {
                new: old for old, new in zip(original_vc.keys(), new_vc.keys())
            }
        else:
            # Labels are unchanged by inference; the mapping is the identity.
            inferred_to_original = {new: new for new in new_vc.keys()}

        # Scale the per-class counts by test_size and round down.
        fold_counts = np.floor(new_vc * self.test_size).astype(int)
        if len(fold_counts) == 0:
            return messages

        # First, look for targets that occur fewer than 2 * num_cv_folds times.
        below_folds = fold_counts.loc[fold_counts < self.cv_folds]
        if len(below_folds):
            below_folds_values = sorted(
                inferred_to_original.get(label)
                for label in below_folds.index.tolist()
            )
            error_msg = "The number of instances of these targets is less than 2 * the number of cross folds = {} instances: {}"
            messages.append(
                DataCheckError(
                    message=error_msg.format(self.cv_folds, below_folds_values),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_FOLDS,
                    details={"target_values": below_folds_values},
                ).to_dict(),
            )

        # Ratio of each class to (itself + majority class). The counts are sorted
        # in descending order, so the first entry is the majority class.
        majority_count = fold_counts.values[0]
        ratios = fold_counts / (fold_counts + majority_count)
        below_threshold = ratios.loc[ratios < self.threshold]
        if len(below_threshold):
            below_threshold_values = [
                inferred_to_original.get(label)
                for label in below_threshold.index.tolist()
            ]
            warning_msg = "The following labels fall below {:.0f}% of the target: {}"
            messages.append(
                DataCheckWarning(
                    message=warning_msg.format(
                        self.threshold * 100,
                        below_threshold_values,
                    ),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_THRESHOLD,
                    details={"target_values": below_threshold_values},
                ).to_dict(),
            )

        # Severe imbalance: a label that is below the ratio threshold AND has
        # fewer than min_samples samples.
        below_min_samples = fold_counts.loc[fold_counts < self.min_samples]
        if len(below_threshold) and len(below_min_samples):
            severe_imbalance = [
                inferred_to_original.get(label)
                for label in below_min_samples.index.tolist()
                if label in below_threshold.index
            ]
            warning_msg = "The following labels in the target have severe class imbalance because they fall under {:.0f}% of the target and have less than {} samples: {}"
            messages.append(
                DataCheckWarning(
                    message=warning_msg.format(
                        self.threshold * 100,
                        self.min_samples,
                        severe_imbalance,
                    ),
                    data_check_name=self.name,
                    message_code=DataCheckMessageCode.CLASS_IMBALANCE_SEVERE,
                    details={"target_values": severe_imbalance},
                ).to_dict(),
            )
        return messages
0 commit comments