Skip to content

Commit 1f593bf

Browse files
MNT improve error message in _num_samples (scikit-learn#30355)
Co-authored-by: Loïc Estève <[email protected]>
1 parent adf74e2 commit 1f593bf

File tree

3 files changed: 16 additions and 3 deletions

sklearn/tree/tests/test_tree.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import copyreg
77
import io
88
import pickle
9+
import re
910
import struct
1011
from itertools import chain, product
1112

@@ -1137,7 +1138,13 @@ def test_sample_weight_invalid():
11371138
clf.fit(X, y, sample_weight=sample_weight)
11381139

11391140
sample_weight = np.array(0)
1140-
expected_err = r"Singleton.* cannot be considered a valid collection"
1141+
1142+
expected_err = re.escape(
1143+
(
1144+
"Input should have at least 1 dimension i.e. satisfy "
1145+
"`len(x.shape) > 0`, got scalar `array(0.)` instead."
1146+
)
1147+
)
11411148
with pytest.raises(TypeError, match=expected_err):
11421149
clf.fit(X, y, sample_weight=sample_weight)
11431150

sklearn/utils/tests/test_validation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,12 @@ def test_check_array_min_samples_and_features_messages():
743743
check_array([], ensure_2d=False)
744744

745745
# Invalid edge case when checking the default minimum sample of a scalar
746-
msg = r"Singleton array array\(42\) cannot be considered a valid" " collection."
746+
msg = re.escape(
747+
(
748+
"Input should have at least 1 dimension i.e. satisfy "
749+
"`len(x.shape) > 0`, got scalar `array(42)` instead."
750+
)
751+
)
747752
with pytest.raises(TypeError, match=msg):
748753
check_array(42, ensure_2d=False)
749754

sklearn/utils/validation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ def _num_samples(x):
397397
if hasattr(x, "shape") and x.shape is not None:
398398
if len(x.shape) == 0:
399399
raise TypeError(
400-
"Singleton array %r cannot be considered a valid collection." % x
400+
"Input should have at least 1 dimension i.e. satisfy "
401+
f"`len(x.shape) > 0`, got scalar `{x!r}` instead."
401402
)
402403
# Check that shape is returning an integer or default to len
403404
# Dask dataframes may not return numeric shape[0] value

0 commit comments

Comments
 (0)