Skip to content

Commit 37405f7

Browse files
authored
add params ensure_2d and allow_nd to check_array (#722)
* add params ensure_2d and allow_nd to check_array allow to use more array with more than 2 dimensions
1 parent 6eac8a0 commit 37405f7

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

dask_ml/model_selection/_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(
148148
self.blockwise = _check_blockwise(blockwise)
149149

150150
def split(self, X, y=None, groups=None):
151-
X = check_array(X)
151+
X = check_array(X, ensure_2d=False, allow_nd=True)
152152
rng = check_random_state(self.random_state)
153153
for i in range(self.n_splits):
154154
seeds = draw_seed(rng, 0, _I4MAX, size=len(X.chunks[0]), dtype="uint")

tests/model_selection/test_split.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,14 @@ def test_split_mixed():
240240
assert len(expected) == len(results)
241241
for a, b in zip(expected, results):
242242
da.utils.assert_eq(a, b)
243+
244+
245+
def test_split_3d_data():
246+
X_3d = np.arange(1.0, 5001.0).reshape((100, 10, 5))
247+
y_3d = np.arange(1.0, 101.0).reshape(100, 1)
248+
249+
r = dask_ml.model_selection.train_test_split(X_3d, y_3d)
250+
X_train, X_test, y_train, y_test = r
251+
252+
assert X_train.ndim == X_3d.ndim
253+
assert X_train.shape[1:] == X_3d.shape[1:]

0 commit comments

Comments
 (0)