Skip to content

Commit 96b8658

Browse files
committed
Fix type check to include numpy integer types
Signed-off-by: abigailt <[email protected]>
1 parent bab5aee commit 96b8658

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

art/estimators/classification/classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def nb_classes(self, nb_classes: int):
114114
"""
115115
Set the number of output classes.
116116
"""
117-
if nb_classes is None or (isinstance(nb_classes, int) and nb_classes < 2):
117+
if nb_classes is None or (isinstance(nb_classes, (int, np.integer)) and nb_classes < 2):
118118
raise ValueError("nb_classes must be greater than or equal to 2.")
119119

120120
self._nb_classes = nb_classes

art/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def check_and_transform_label_format(
801801
"""
802802
labels_return = labels
803803

804-
if nb_classes is not None and not isinstance(nb_classes, int):
804+
if nb_classes is not None and not isinstance(nb_classes, (int, np.integer)):
805805
raise TypeError("nb_classes that is not an integer is not supported")
806806

807807
if len(labels.shape) == 2 and labels.shape[1] > 1: # multi-class, one-hot encoded

0 commit comments

Comments
 (0)