diff --git a/dask_ml/datasets.py b/dask_ml/datasets.py
index a561ee0d5..fb2afb831 100644
--- a/dask_ml/datasets.py
+++ b/dask_ml/datasets.py
@@ -460,3 +460,51 @@ def make_classification_df(
     )
 
     return X_df, y_series
+
+
+def make_s_curve(
+    n_samples=100,
+    noise=0.0,
+    random_state=None,
+    chunks=None,
+):
+    """
+    Generate an S curve dataset.
+
+    Parameters
+    ----------
+    n_samples : int, default=100
+        The number of sample points on the S curve.
+    noise : float, default=0.0
+        The standard deviation of the gaussian noise added to every
+        coordinate of ``X``. Values less than or equal to zero add no noise.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation for dataset creation. Pass an int
+        for reproducible output across multiple function calls.
+        See :term:`Glossary <random_state>`.
+    chunks : int or None, default=None
+        Number of rows per dask array block. ``None`` keeps all rows in a
+        single block.
+
+    Returns
+    -------
+    X : dask.array of shape (n_samples, 3)
+        The points.
+    t : dask.array of shape (n_samples,)
+        The univariate position of the sample according to the main dimension
+        of the points in the manifold.
+    """
+    rng = dask_ml.utils.check_random_state(random_state)
+
+    # t is drawn uniformly from [-3*pi/2, 3*pi/2].
+    t_scale = 3 * np.pi * 0.5
+    t = rng.uniform(low=-t_scale, high=t_scale, size=n_samples, chunks=(chunks,))
+    X = da.empty(shape=(n_samples, 3), chunks=(chunks, 3), dtype="f8")
+    X[:, 0] = da.sin(t)
+    X[:, 1] = rng.uniform(low=0, high=2, size=n_samples, chunks=(chunks,))
+    X[:, 2] = da.sign(t) * (da.cos(t) - 1)
+
+    if noise > 0:
+        X += rng.normal(scale=noise, size=X.shape, chunks=X.chunks)
+
+    return X, t
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index d221e2963..b90979441 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -55,6 +55,7 @@ def test_make_regression():
         dask_ml.datasets.make_classification,
         dask_ml.datasets.make_counts,
         dask_ml.datasets.make_regression,
+        dask_ml.datasets.make_s_curve,
     ],
 )
 def test_deterministic(generator, scheduler):
@@ -80,3 +81,39 @@ def test_make_classification_df():
     assert len(X_df) == 100
     assert len(y_series) == 100
     assert isinstance(y_series, dask.dataframe.core.Series)
+
+
+def test_make_s_curve():
+    X, X_color = dask_ml.datasets.make_s_curve(
+        n_samples=200,
+        random_state=0,
+        chunks=100,
+    )
+
+    assert isinstance(X, da.Array)
+    assert X.shape == (200, 3)
+    assert X.chunks == ((100, 100), (3,))
+    assert X.compute().shape == X.shape
+
+    assert isinstance(X_color, da.Array)
+    assert X_color.shape == (200,)
+    assert X_color.chunks == ((100, 100),)
+    assert X_color.compute().shape == X_color.shape
+
+    # Without noise the points lie exactly on the curve parametrised by t.
+    assert da.allclose(X[:, 0], da.sin(X_color)).compute()
+    assert da.allclose(X[:, 2], da.sign(X_color) * (da.cos(X_color) - 1)).compute()
+    assert ((X[:, 1] >= 0) & (X[:, 1] <= 2)).all().compute()
+
+
+def test_make_s_curve_noise():
+    X, t = dask_ml.datasets.make_s_curve(
+        n_samples=200,
+        noise=0.1,
+        random_state=0,
+        chunks=100,
+    )
+
+    assert X.shape == (200, 3)
+    # Positive noise must perturb the points away from the exact curve.
+    assert not da.allclose(X[:, 0], da.sin(t)).compute()