Skip to content

Commit 6e2bc39

Browse files
authored
Merge pull request #15 from alteryx/datachecks_features
DataChecks migration from `EvalML` into `CheckMates`
2 parents 75fccc3 + 9f2c0f0 commit 6e2bc39

29 files changed

+6226
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
.DS_Store
44
Checkmates.egg-info/
55
.python-version
6-
pdm.lock
6+
pdm.lock
7+
.pdm-python

checkmates/data_checks/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,32 @@
2323
DataCheckMessageType,
2424
)
2525
from checkmates.data_checks.checks.id_columns_data_check import IDColumnsDataCheck
26+
from checkmates.data_checks.checks.null_data_check import NullDataCheck
27+
from checkmates.data_checks.checks.class_imbalance_data_check import (
28+
ClassImbalanceDataCheck,
29+
)
30+
from checkmates.data_checks.checks.no_variance_data_check import NoVarianceDataCheck
31+
from checkmates.data_checks.checks.outliers_data_check import OutliersDataCheck
32+
from checkmates.data_checks.checks.uniqueness_data_check import UniquenessDataCheck
33+
from checkmates.data_checks.checks.ts_splitting_data_check import (
34+
TimeSeriesSplittingDataCheck,
35+
)
36+
from checkmates.data_checks.checks.ts_parameters_data_check import (
37+
TimeSeriesParametersDataCheck,
38+
)
39+
from checkmates.data_checks.checks.target_leakage_data_check import (
40+
TargetLeakageDataCheck,
41+
)
42+
from checkmates.data_checks.checks.target_distribution_data_check import (
43+
TargetDistributionDataCheck,
44+
)
45+
from checkmates.data_checks.checks.sparsity_data_check import SparsityDataCheck
46+
from checkmates.data_checks.checks.datetime_format_data_check import (
47+
DateTimeFormatDataCheck,
48+
)
49+
from checkmates.data_checks.checks.multicollinearity_data_check import (
50+
MulticollinearityDataCheck,
51+
)
52+
2653

2754
from checkmates.data_checks.datacheck_meta.utils import handle_data_check_action_code
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""Data check that checks if any of the target labels are imbalanced, or if the number of values for each target are below 2 times the number of CV folds.
2+
3+
Use for classification problems.
4+
"""
5+
import numpy as np
6+
import pandas as pd
7+
8+
from checkmates.data_checks import (
9+
DataCheck,
10+
DataCheckError,
11+
DataCheckMessageCode,
12+
DataCheckWarning,
13+
)
14+
from checkmates.utils import infer_feature_types
15+
16+
17+
class ClassImbalanceDataCheck(DataCheck):
    """Flag classification targets whose classes are imbalanced or too small for cross-validation.

    Up to three findings are reported for a target:

    * an error for classes with fewer instances than 2 * ``num_cv_folds``,
    * a warning for classes whose share relative to the majority class is below ``threshold``,
    * a warning for classes that are below ``threshold`` and also have fewer than ``min_samples`` instances
      (severe imbalance).

    Args:
        threshold (float): The minimum threshold allowed for class imbalance before a warning is raised.
            The ratio compared against it is the number of samples in a class divided by the sum of the samples
            in that class and in the majority class. For example, a multiclass case with [900, 900, 100] samples
            for classes 0, 1, and 2, respectively, gives class 2 a ratio of 0.10 (100 / (900 + 100)).
            Must be in the range (0, 0.5]. Defaults to 0.10.
        min_samples (int): The minimum number of samples per accepted class. A class that is both below the
            threshold and below min_samples is considered severely imbalanced. Must be greater than 0.
            Defaults to 100.
        num_cv_folds (int): The number of cross-validation folds. Must be non-negative; with 0 no class can fall
            below the fold minimum, so that error is never reported. Defaults to 3.
        test_size (None, float, int): Factor in (0, 1] applied to every class count (the product is floored)
            before the checks run, so that imbalance is evaluated prior to splitting the data into training and
            validation/test sets. None is treated as 1, which leaves the counts unchanged. Defaults to None.

    Raises:
        ValueError: If threshold is not within the range (0, 0.5].
        ValueError: If min_samples is not greater than 0.
        ValueError: If num_cv_folds is negative.
        ValueError: If test_size is given and is not a number in the range (0, 1].
    """

    def __init__(self, threshold=0.1, min_samples=100, num_cv_folds=3, test_size=None):
        if threshold <= 0 or threshold > 0.5:
            raise ValueError(
                "Provided threshold {} is not within the range (0, 0.5]".format(
                    threshold,
                ),
            )
        self.threshold = threshold

        if min_samples <= 0:
            raise ValueError(
                "Provided value min_samples {} is not greater than 0".format(
                    min_samples,
                ),
            )
        self.min_samples = min_samples

        if num_cv_folds < 0:
            raise ValueError(
                "Provided number of CV folds {} is less than 0".format(num_cv_folds),
            )
        # Stored already doubled: validate() compares class counts against
        # 2 * num_cv_folds and reports this doubled figure in its error message.
        self.cv_folds = num_cv_folds * 2

        if test_size is None:
            # A factor of 1 leaves the class counts unscaled.
            test_size = 1
        elif not isinstance(test_size, (int, float)) or not 0 < test_size <= 1:
            raise ValueError(
                "Parameter test_size must be a number between 0 and less than or equal to 1",
            )
        self.test_size = test_size

    def validate(self, X, y):
        """Check if any target labels are imbalanced beyond a threshold for binary and multiclass problems.

        Ignores NaN values in target labels if they appear.

        Args:
            X (pd.DataFrame, np.ndarray): Features. Ignored.
            y (pd.Series, np.ndarray): Target labels to check for imbalanced data.

        Returns:
            list[dict]: One serialized message per finding, in this order: a DataCheckError if any class has
                fewer instances than 2 * num_cv_folds, a DataCheckWarning if any class falls below the
                threshold, and a DataCheckWarning if any class is below both the threshold and min_samples.
                Empty if the target has no values or nothing is found.

        Examples:
            >>> import pandas as pd
            ...
            >>> X = pd.DataFrame()
            >>> y = pd.Series([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

            In this binary example, the target class 0 is present in fewer than 10% (threshold=0.10) of instances, and fewer than 2 * the number
            of cross folds (2 * 3 = 6). Therefore, both a warning and an error are returned as part of the Class Imbalance Data Check.
            In addition, if a target is present with fewer than `min_samples` occurrences (default is 100) and is under the threshold,
            a severe class imbalance warning will be raised.

            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.10)
            >>> assert class_imb_dc.validate(X, y) == [
            ...     {
            ...         "message": "The number of instances of these targets is less than 2 * the number of cross folds = 6 instances: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "error",
            ...         "code": "CLASS_IMBALANCE_BELOW_FOLDS",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels fall below 10% of the target: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels in the target have severe class imbalance because they fall under 10% of the target and have less than 100 samples: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_SEVERE",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     }
            ... ]


            In this multiclass example, the target class 0 is present in fewer than 30% of observations, however with 1 cv fold, the minimum
            number of instances required is 2 * 1 = 2. Therefore a warning, but not an error, is raised.

            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, min_samples=5, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == [
            ...     {
            ...         "message": "The following labels fall below 30% of the target: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_BELOW_THRESHOLD",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     },
            ...     {
            ...         "message": "The following labels in the target have severe class imbalance because they fall under 30% of the target and have less than 5 samples: [0]",
            ...         "data_check_name": "ClassImbalanceDataCheck",
            ...         "level": "warning",
            ...         "code": "CLASS_IMBALANCE_SEVERE",
            ...         "details": {"target_values": [0], "rows": None, "columns": None},
            ...         "action_options": []
            ...     }
            ... ]
            ...
            >>> y = pd.Series([0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
            >>> class_imb_dc = ClassImbalanceDataCheck(threshold=0.30, num_cv_folds=1)
            >>> assert class_imb_dc.validate(X, y) == []
        """
        messages = []

        # Label order must be captured before type inference, which may recode the labels.
        original_vc = pd.Series(y).value_counts(sort=True)
        y = infer_feature_types(y)
        class_counts = y.value_counts(sort=True)

        inferred_labels = class_counts.keys()
        if str(y.ww.logical_type) in ["Boolean", "BooleanNullable"]:
            # Boolean inference may have replaced the caller's labels, so pair the
            # inferred labels with the original ones position by position (both
            # value counts are sorted by frequency) for use in the messages.
            inferred_to_original = dict(zip(inferred_labels, original_vc.keys()))
        else:
            # For every other logical type the labels are reported as they are.
            inferred_to_original = dict(zip(inferred_labels, inferred_labels))
        # Reverse lookup, needed by the severe-imbalance check to test whether an
        # original label is among the classes found below the threshold.
        original_to_inferred = {
            original: inferred for inferred, original in inferred_to_original.items()
        }

        def original_labels_of(selection):
            """Return the caller-facing labels for the index of ``selection``."""
            return [
                inferred_to_original.get(label) for label in selection.index.tolist()
            ]

        # Scale each class count by test_size and round down to whole instances.
        fold_counts = np.floor(class_counts * self.test_size).astype(int)
        if len(fold_counts) == 0:
            return messages

        # First finding: classes with fewer instances than twice the number of CV folds.
        too_few_for_folds = fold_counts.loc[fold_counts < self.cv_folds]
        if len(too_few_for_folds):
            fold_labels = sorted(original_labels_of(too_few_for_folds))
            error_msg = "The number of instances of these targets is less than 2 * the number of cross folds = {} instances: {}"
            fold_error = DataCheckError(
                message=error_msg.format(self.cv_folds, fold_labels),
                data_check_name=self.name,
                message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_FOLDS,
                details={"target_values": fold_labels},
            )
            messages.append(fold_error.to_dict())

        # Second finding: each class's size relative to itself plus the majority
        # class. The value counts are sorted, so the first entry is the majority.
        majority_count = fold_counts.values[0]
        class_ratios = fold_counts / (fold_counts + majority_count)
        under_threshold = class_ratios.loc[class_ratios < self.threshold]
        threshold_percent = self.threshold * 100
        if len(under_threshold):
            threshold_labels = original_labels_of(under_threshold)
            warning_msg = "The following labels fall below {:.0f}% of the target: {}"
            threshold_warning = DataCheckWarning(
                message=warning_msg.format(threshold_percent, threshold_labels),
                data_check_name=self.name,
                message_code=DataCheckMessageCode.CLASS_IMBALANCE_BELOW_THRESHOLD,
                details={"target_values": threshold_labels},
            )
            messages.append(threshold_warning.to_dict())

        # Third finding: classes that are below the threshold and also below min_samples.
        under_min_samples = fold_counts.loc[fold_counts < self.min_samples]
        if len(under_threshold) and len(under_min_samples):
            # `in` on a Series tests membership in its index, i.e. the class labels.
            severe_imbalance = [
                label
                for label in original_labels_of(under_min_samples)
                if original_to_inferred.get(label) in under_threshold
            ]
            warning_msg = "The following labels in the target have severe class imbalance because they fall under {:.0f}% of the target and have less than {} samples: {}"
            severe_warning = DataCheckWarning(
                message=warning_msg.format(
                    threshold_percent,
                    self.min_samples,
                    severe_imbalance,
                ),
                data_check_name=self.name,
                message_code=DataCheckMessageCode.CLASS_IMBALANCE_SEVERE,
                details={"target_values": severe_imbalance},
            )
            messages.append(severe_warning.to_dict())
        return messages

0 commit comments

Comments
 (0)