Skip to content

Commit a67e833

Browse files
ENH Array API for check_consistent_length (scikit-learn#29519)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 29f6ca3 commit a67e833

File tree

6 files changed

+36
-8
lines changed

6 files changed

+36
-8
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Tools
149149
-----
150150

151151
- :func:`model_selection.train_test_split`
152+
- :func:`utils.check_consistent_length`
152153

153154
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
154155
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :func:`sklearn.utils.check_consistent_length` now supports Array API compatible
2+
inputs.
3+
By :user:`Stefanie Senger <StefanieSenger>`

sklearn/metrics/cluster/_supervised.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ def fowlkes_mallows_score(labels_true, labels_pred, *, sparse=False):
11841184
11851185
.. versionadded:: 0.18
11861186
1187-
The Fowlkes-Mallows index (FMI) is defined as the geometric mean between of
1187+
The Fowlkes-Mallows index (FMI) is defined as the geometric mean of
11881188
the precision and recall::
11891189
11901190
FMI = TP / sqrt((TP + FP) * (TP + FN))

sklearn/utils/_array_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,11 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
536536
-------
537537
namespace : module
538538
Namespace shared by array objects. If any of the `arrays` are not arrays,
539-
the namespace defaults to NumPy.
539+
the namespace defaults to the NumPy namespace.
540540
541541
is_array_api_compliant : bool
542-
True if the arrays are containers that implement the Array API spec.
542+
True if the arrays are containers that implement the array API spec (see
543+
https://data-apis.org/array-api/latest/index.html).
543544
Always False when array_api_dispatch=False.
544545
"""
545546
array_api_dispatch = get_config()["array_api_dispatch"]

sklearn/utils/tests/test_validation.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
check_X_y,
3535
deprecated,
3636
)
37+
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
3738
from sklearn.utils._mocking import (
3839
MockDataFrame,
3940
_MockEstimatorOnOffPrediction,
4041
)
4142
from sklearn.utils._testing import (
4243
SkipTest,
4344
TempMemmap,
45+
_array_api_for_tests,
4446
_convert_container,
4547
assert_allclose,
4648
assert_allclose_dense_sparse,
@@ -1007,6 +1009,8 @@ def test_check_is_fitted_with_attributes(wrap):
10071009

10081010

10091011
def test_check_consistent_length():
1012+
"""Test that `check_consistent_length` raises on inconsistent lengths and that wrong
1013+
input types trigger TypeErrors."""
10101014
check_consistent_length([1], [2], [3], [4], [5])
10111015
check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ["a", "b"])
10121016
check_consistent_length([1], (2,), np.array([3]), sp.csr_matrix((1, 2)))
@@ -1016,16 +1020,37 @@ def test_check_consistent_length():
10161020
check_consistent_length([1, 2], 1)
10171021
with pytest.raises(TypeError, match=r"got <\w+ 'object'>"):
10181022
check_consistent_length([1, 2], object())
1019-
10201023
with pytest.raises(TypeError):
10211024
check_consistent_length([1, 2], np.array(1))
1022-
10231025
# Despite ensembles having __len__ they must raise TypeError
10241026
with pytest.raises(TypeError, match="Expected sequence or array-like"):
10251027
check_consistent_length([1, 2], RandomForestRegressor())
10261028
# XXX: We should have a test with a string, but what is correct behaviour?
10271029

10281030

1031+
@pytest.mark.parametrize(
1032+
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
1033+
)
1034+
def test_check_consistent_length_array_api(array_namespace, device, _):
1035+
"""Test that check_consistent_length works with different array types."""
1036+
xp = _array_api_for_tests(array_namespace, device)
1037+
1038+
with config_context(array_api_dispatch=True):
1039+
check_consistent_length(
1040+
xp.asarray([1, 2, 3], device=device),
1041+
xp.asarray([[1, 1], [2, 2], [3, 3]], device=device),
1042+
[1, 2, 3],
1043+
["a", "b", "c"],
1044+
np.asarray(("a", "b", "c"), dtype=object),
1045+
sp.csr_array([[0, 1], [1, 0], [0, 0]]),
1046+
)
1047+
1048+
with pytest.raises(ValueError, match="inconsistent numbers of samples"):
1049+
check_consistent_length(
1050+
xp.asarray([1, 2], device=device), xp.asarray([1], device=device)
1051+
)
1052+
1053+
10291054
def test_check_dataframe_fit_attribute():
10301055
# check pandas dataframe with 'fit' column does not raise error
10311056
# https://github.com/scikit-learn/scikit-learn/issues/8415

sklearn/utils/validation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,8 @@ def check_consistent_length(*arrays):
468468
>>> b = [2, 3, 4]
469469
>>> check_consistent_length(a, b)
470470
"""
471-
472471
lengths = [_num_samples(X) for X in arrays if X is not None]
473-
uniques = np.unique(lengths)
474-
if len(uniques) > 1:
472+
if len(set(lengths)) > 1:
475473
raise ValueError(
476474
"Found input variables with inconsistent numbers of samples: %r"
477475
% [int(l) for l in lengths]

0 commit comments

Comments
 (0)