Skip to content

Commit 0d5e41a

Browse files
authored
raise error instead of warning when a user has missing data and add check for train data in addition to test (#2143)
* raise error instead of warning when a user has missing data and add check for train data in addition to test
* address comments
1 parent 9bfca4b commit 0d5e41a

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

responsibleai/responsibleai/rai_insights/rai_insights.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -597,14 +597,9 @@ def _validate_rai_insights_input_parameters(
597597
"identified as categorical features: "
598598
f"{non_categorical_or_time_string_columns}")
599599

600-
list_of_feature_having_missing_values = []
601-
for feature in test.columns.tolist():
602-
if np.any(test[feature].isnull()):
603-
list_of_feature_having_missing_values.append(feature)
604-
if len(list_of_feature_having_missing_values) > 0:
605-
warnings.warn(
606-
f"Features {list_of_feature_having_missing_values} "
607-
"have missing values in test data")
600+
# Check if any of the data is missing in test and train data
601+
self._validate_data_is_not_missing(test, "test")
602+
self._validate_data_is_not_missing(train, "train")
608603

609604
self._validate_feature_metadata(
610605
feature_metadata, train, task_type, model, target_column)
@@ -717,6 +712,17 @@ def _validate_classes(
717712
if_predictions=True
718713
)
719714

715+
def _validate_data_is_not_missing(self, data, data_name):
716+
"""Validates that data is not missing (ie null)"""
717+
list_of_feature_having_missing_values = []
718+
for feature in data.columns.tolist():
719+
if np.any(data[feature].isnull()):
720+
list_of_feature_having_missing_values.append(feature)
721+
if len(list_of_feature_having_missing_values) > 0:
722+
raise UserConfigValidationException(
723+
f"Features {list_of_feature_having_missing_values} "
724+
f"have missing values in {data_name} data.")
725+
720726
def _validate_feature_metadata(
721727
self, feature_metadata, train, task_type, model, target_column):
722728
"""Validates the feature metadata."""

responsibleai/tests/rai_insights/test_rai_insights_validations.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_validate_unsupported_task_type(self, forecasting_enabled):
4848
task_type='regre',
4949
forecasting_enabled=forecasting_enabled)
5050

51-
def test_missing_data_warnings(self):
51+
def test_missing_test_data(self):
5252
train_data = {
5353
'Column1': [10, 20, 90, 40, 50],
5454
'Column2': [10, 20, 90, 40, 50],
@@ -57,7 +57,39 @@ def test_missing_data_warnings(self):
5757
train = pd.DataFrame(train_data)
5858

5959
test_data = {
60-
'Column1': [10, 20, np.nan, 40, 50],
60+
'Column1': [10, 20, 90, 40, 50],
61+
'Column2': [10, 20, 90, 40, 50],
62+
'Target': [10, 20, np.nan, 40, 50]
63+
}
64+
test = pd.DataFrame(test_data)
65+
66+
X_train = train.drop(columns=['Target'])
67+
y_train = train['Target'].values
68+
model = create_complex_classification_pipeline(
69+
X_train, y_train, ['Column1', 'Column2'], [])
70+
71+
with pytest.raises(
72+
UserConfigValidationException,
73+
match="['Column1']") as ucve:
74+
RAIInsights(
75+
model=model,
76+
train=train,
77+
test=test,
78+
target_column='Target',
79+
task_type='classification')
80+
assert "Features ['Target'] have missing values in " + \
81+
"test data" in str(ucve.value)
82+
83+
def test_missing_train_data(self):
84+
train_data = {
85+
'Column1': [10, 20, 90, 40, 50],
86+
'Column2': [10, 20, np.nan, 40, 50],
87+
'Target': [10, 20, 90, 40, 50]
88+
}
89+
train = pd.DataFrame(train_data)
90+
91+
test_data = {
92+
'Column1': [10, 20, 90, 40, 50],
6193
'Column2': [10, 20, 90, 40, 50],
6294
'Target': [10, 20, 90, 40, 50]
6395
}
@@ -68,15 +100,17 @@ def test_missing_data_warnings(self):
68100
model = create_complex_classification_pipeline(
69101
X_train, y_train, ['Column1', 'Column2'], [])
70102

71-
with pytest.warns(
72-
UserWarning,
73-
match="['Column1']"):
103+
with pytest.raises(
104+
UserConfigValidationException,
105+
match="['Column2']") as ucve:
74106
RAIInsights(
75107
model=model,
76108
train=train,
77109
test=test,
78110
target_column='Target',
79111
task_type='classification')
112+
assert "Features ['Column2'] have missing values in " + \
113+
"train data" in str(ucve.value)
80114

81115
def test_validate_test_data_size(self):
82116
X_train, X_test, y_train, y_test, _, _ = \

0 commit comments

Comments
 (0)