Skip to content

Commit 234c439

Browse files
authored
add validation target column type in the classification scenario (#2127)
* add validation for the classification scenario when a user inputs a float-type target column * use is_float_dtype * add a check for whether the float values can be converted to integers
1 parent 4aa5fb0 commit 234c439

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

responsibleai/responsibleai/rai_insights/rai_insights.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,17 @@ def _validate_rai_insights_input_parameters(
566566
f"Error finding unique values in column {column}. "
567567
"Please check your test data.")
568568

569+
# Validate that the target column isn't continuous if the
570+
# user is running a classification scenario
571+
# To address error thrown from sklearn here: # noqa: E501
572+
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/multiclass.py#L197
573+
y_data = train[target_column]
574+
if (task_type == ModelTask.CLASSIFICATION and
575+
pd.api.types.is_float_dtype(y_data.dtype) and
576+
np.any(y_data != y_data.astype(int))):
577+
raise UserConfigValidationException(
578+
"Target column type must not be continuous "
579+
"for classification scenario.")
569580
# Check if any features exist that are not numeric, datetime, or
570581
# categorical.
571582
train_features = train.drop(columns=[target_column]).columns

responsibleai/tests/rai_insights/test_rai_insights_validations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,31 @@ def test_validate_categorical_features_not_having_train_features(self):
190190
task_type='classification',
191191
categorical_features=['not_a_feature'])
192192

193+
def test_validate_multi_classification_continuous_target_column(self):
    """Expect a validation error for a fractional float target column.

    Builds a small frame whose target column holds non-integral floats
    and checks that constructing RAIInsights for a classification task
    raises UserConfigValidationException with the expected message.
    """
    frame = pd.DataFrame({
        'Column1': [10, 20, 90, 40, 50],
        'Column2': [10, 20, 90, 40, 50],
        'Target': [.1, .2, .9, .4, .5],
    })
    dataset = frame.drop(columns=['Target'])
    dataset[TARGET] = frame['Target'].values

    # The model itself is created with valid integer labels, so the
    # failure expected below comes from the dataset's target column.
    valid_labels = np.array([1, 1, 2, 0, 1])
    classifier = create_lightgbm_classifier(dataset, valid_labels)

    expected_message = ("Target column type must not be continuous "
                        "for classification scenario.")
    with pytest.raises(
            UserConfigValidationException,
            match=expected_message):
        RAIInsights(
            model=classifier,
            train=dataset,
            test=dataset,
            target_column=TARGET,
            task_type='classification')
217+
193218
def test_validate_serializer(self):
194219
X_train, X_test, y_train, y_test, _, _ = \
195220
create_cancer_data(return_dataframe=True)

0 commit comments

Comments
 (0)