Skip to content

Commit cf11263

Browse files
authored
Merge pull request #2505 from abigailgold/sklearn_nbclasses
Support sklearn models with multiple outputs
2 parents 0b4bb68 + bd58b1a commit cf11263

File tree

3 files changed: 29 additions and 3 deletions.

art/estimators/classification/classifier.py

Lines changed: 1 addition & 1 deletion
--- a/art/estimators/classification/classifier.py
+++ b/art/estimators/classification/classifier.py
@@ -116,7 +116,7 @@ def nb_classes(self, nb_classes: int):
         """
         Set the number of output classes.
         """
-        if nb_classes is None or nb_classes < 2:
+        if nb_classes is None or (isinstance(nb_classes, (int, np.integer)) and nb_classes < 2):
             raise ValueError("nb_classes must be greater than or equal to 2.")
 
         self._nb_classes = nb_classes

art/utils.py

Lines changed: 5 additions & 2 deletions
--- a/art/utils.py
+++ b/art/utils.py
@@ -799,15 +799,18 @@ def check_and_transform_label_format(
     labels: np.ndarray, nb_classes: int | None, return_one_hot: bool = True
 ) -> np.ndarray:
     """
-    Check label format and transform to one-hot-encoded labels if necessary
+    Check label format and transform to one-hot-encoded labels if necessary. Only supports single-output classification.
 
     :param labels: An array of integer labels of shape `(nb_samples,)`, `(nb_samples, 1)` or `(nb_samples, nb_classes)`.
-    :param nb_classes: The number of classes. If None the number of classes is determined automatically.
+    :param nb_classes: The number of classes, as an integer. If None the number of classes is determined automatically.
     :param return_one_hot: True if returning one-hot encoded labels, False if returning index labels.
     :return: Labels with shape `(nb_samples, nb_classes)` (one-hot) or `(nb_samples,)` (index).
     """
     labels_return = labels
 
+    if nb_classes is not None and not isinstance(nb_classes, (int, np.integer)):
+        raise TypeError("nb_classes that is not an integer is not supported")
+
     if len(labels.shape) == 2 and labels.shape[1] > 1:  # multi-class, one-hot encoded
         if not return_one_hot:
             labels_return = np.argmax(labels, axis=1)

tests/estimators/classification/test_scikitlearn.py

Lines changed: 23 additions & 0 deletions
--- a/tests/estimators/classification/test_scikitlearn.py
+++ b/tests/estimators/classification/test_scikitlearn.py
@@ -47,6 +47,7 @@
     ScikitlearnSVC,
 )
 from art.estimators.classification.scikitlearn import SklearnClassifier
+from art.utils import check_and_transform_label_format
 
 from tests.utils import TestBase, master_seed
 
@@ -80,6 +81,28 @@ def test_save(self):
     def test_clone_for_refitting(self):
         _ = self.classifier.clone_for_refitting()
 
+    def test_multi_label(self):
+        x_train = self.x_train_iris
+        y_train = self.y_train_iris
+        x_test = self.x_test_iris
+        y_test = self.y_test_iris
+
+        # make multi-label binary
+        y_train = np.column_stack((y_train, y_train, y_train))
+        y_train[y_train > 1] = 1
+        y_test = np.column_stack((y_test, y_test, y_test))
+        y_test[y_test > 1] = 1
+
+        underlying_model = DecisionTreeClassifier()
+        underlying_model.fit(x_train, y_train)
+        model = ScikitlearnDecisionTreeClassifier(model=underlying_model)
+
+        pred = model.predict(x_test)
+        assert pred[0].shape[0] == x_test.shape[0]
+        assert isinstance(model.nb_classes, np.ndarray)
+        with self.assertRaises(TypeError):
+            check_and_transform_label_format(y_train, nb_classes=model.nb_classes)
+
 
 class TestScikitlearnExtraTreeClassifier(TestBase):
     @classmethod

0 commit comments

Comments
 (0)