Skip to content

Commit c3da43f

Browse files
authored
API: cluster: transition to rng (SPEC 7) (scipy#21878)
* API: cluster: transition to rng (SPEC 7)
1 parent f585942 commit c3da43f

File tree

2 files changed

+47
-49
lines changed

2 files changed

+47
-49
lines changed

scipy/cluster/tests/test_vq.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,21 @@ def test_large_features(self, xp):
248248
data[:x.shape[0]] = x
249249
data[x.shape[0]:] = y
250250

251-
kmeans(xp.asarray(data), 2)
251+
# use `seed` to ensure backwards compatibility after SPEC7
252+
kmeans(xp.asarray(data), 2, seed=1)
252253

253254
def test_kmeans_simple(self, xp):
254-
np.random.seed(54321)
255+
rng = np.random.default_rng(54321)
255256
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
256-
code1 = kmeans(xp.asarray(X), xp.asarray(initc), iter=1)[0]
257+
code1 = kmeans(xp.asarray(X), xp.asarray(initc), iter=1, rng=rng)[0]
257258
xp_assert_close(code1, xp.asarray(CODET2))
258259

259260
@pytest.mark.skipif(SCIPY_ARRAY_API,
260261
reason='`np.matrix` unsupported in array API mode')
261262
def test_kmeans_simple_matrix(self, xp):
262-
np.random.seed(54321)
263+
rng = np.random.default_rng(54321)
263264
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
264-
code1 = kmeans(matrix(X), matrix(initc), iter=1)[0]
265+
code1 = kmeans(matrix(X), matrix(initc), iter=1, rng=rng)[0]
265266
xp_assert_close(code1, CODET2)
266267

267268
def test_kmeans_lost_cluster(self, xp):
@@ -281,23 +282,23 @@ def test_kmeans_lost_cluster(self, xp):
281282
assert_raises(ClusterError, kmeans2, data, initk, missing='raise')
282283

283284
def test_kmeans2_simple(self, xp):
284-
np.random.seed(12345678)
285+
rng = np.random.default_rng(12345678)
285286
initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
286287
arrays = [xp.asarray] if SCIPY_ARRAY_API else [np.asarray, matrix]
287288
for tp in arrays:
288-
code1 = kmeans2(tp(X), tp(initc), iter=1)[0]
289-
code2 = kmeans2(tp(X), tp(initc), iter=2)[0]
289+
code1 = kmeans2(tp(X), tp(initc), iter=1, rng=rng)[0]
290+
code2 = kmeans2(tp(X), tp(initc), iter=2, rng=rng)[0]
290291

291292
xp_assert_close(code1, xp.asarray(CODET1))
292293
xp_assert_close(code2, xp.asarray(CODET2))
293294

294295
@pytest.mark.skipif(SCIPY_ARRAY_API,
295296
reason='`np.matrix` unsupported in array API mode')
296297
def test_kmeans2_simple_matrix(self, xp):
297-
np.random.seed(12345678)
298+
rng = np.random.default_rng(12345678)
298299
initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
299-
code1 = kmeans2(matrix(X), matrix(initc), iter=1)[0]
300-
code2 = kmeans2(matrix(X), matrix(initc), iter=2)[0]
300+
code1 = kmeans2(matrix(X), matrix(initc), iter=1, rng=rng)[0]
301+
code2 = kmeans2(matrix(X), matrix(initc), iter=2, rng=rng)[0]
301302

302303
xp_assert_close(code1, CODET1)
303304
xp_assert_close(code2, CODET2)
@@ -308,7 +309,9 @@ def test_kmeans2_rank1(self, xp):
308309

309310
initc = data1[:3]
310311
code = xp_copy(initc, xp=xp)
311-
kmeans2(data1, code, iter=1)[0]
312+
313+
# use `seed` to ensure backwards compatibility after SPEC7
314+
kmeans2(data1, code, iter=1, seed=1)[0]
312315
kmeans2(data1, code, iter=2)[0]
313316

314317
def test_kmeans2_rank1_2(self, xp):
@@ -326,21 +329,21 @@ def test_kmeans2_high_dim(self, xp):
326329
@skip_xp_backends('jax.numpy',
327330
reason='jax arrays do not support item assignment')
328331
def test_kmeans2_init(self, xp):
329-
np.random.seed(12345)
332+
rng = np.random.default_rng(12345678)
330333
data = xp.asarray(TESTDATA_2D)
331334
k = 3
332335

333-
kmeans2(data, k, minit='points')
334-
kmeans2(data[:, 1], k, minit='points') # special case (1-D)
336+
kmeans2(data, k, minit='points', rng=rng)
337+
kmeans2(data[:, 1], k, minit='points', rng=rng) # special case (1-D)
335338

336-
kmeans2(data, k, minit='++')
337-
kmeans2(data[:, 1], k, minit='++') # special case (1-D)
339+
kmeans2(data, k, minit='++', rng=rng)
340+
kmeans2(data[:, 1], k, minit='++', rng=rng) # special case (1-D)
338341

339342
# minit='random' can give warnings, filter those
340343
with suppress_warnings() as sup:
341344
sup.filter(message="One of the clusters is empty. Re-run.")
342-
kmeans2(data, k, minit='random')
343-
kmeans2(data[:, 1], k, minit='random') # special case (1-D)
345+
kmeans2(data, k, minit='random', rng=rng)
346+
kmeans2(data[:, 1], k, minit='random', rng=rng) # special case (1-D)
344347

345348
@pytest.mark.skipif(sys.platform == 'win32',
346349
reason='Fails with MemoryError in Wine.')
@@ -377,28 +380,29 @@ def test_kmeans_large_thres(self, xp):
377380
reason='jax arrays do not support item assignment')
378381
def test_kmeans2_kpp_low_dim(self, xp):
379382
# Regression test for gh-11462
383+
rng = np.random.default_rng(2358792345678234568)
380384
prev_res = xp.asarray([[-1.95266667, 0.898],
381385
[-3.153375, 3.3945]], dtype=xp.float64)
382-
np.random.seed(42)
383-
res, _ = kmeans2(xp.asarray(TESTDATA_2D), 2, minit='++')
386+
res, _ = kmeans2(xp.asarray(TESTDATA_2D), 2, minit='++', rng=rng)
384387
xp_assert_close(res, prev_res)
385388

386389
@skip_xp_backends('jax.numpy',
387390
reason='jax arrays do not support item assignment')
388391
def test_kmeans2_kpp_high_dim(self, xp):
389392
# Regression test for gh-11462
393+
rng = np.random.default_rng(23587923456834568)
390394
n_dim = 100
391395
size = 10
392396
centers = np.vstack([5 * np.ones(n_dim),
393397
-5 * np.ones(n_dim)])
394-
np.random.seed(42)
398+
395399
data = np.vstack([
396-
np.random.multivariate_normal(centers[0], np.eye(n_dim), size=size),
397-
np.random.multivariate_normal(centers[1], np.eye(n_dim), size=size)
400+
rng.multivariate_normal(centers[0], np.eye(n_dim), size=size),
401+
rng.multivariate_normal(centers[1], np.eye(n_dim), size=size)
398402
])
399403

400404
data = xp.asarray(data)
401-
res, _ = kmeans2(data, 2, minit='++')
405+
res, _ = kmeans2(data, 2, minit='++', rng=rng)
402406
xp_assert_equal(xp.sign(res), xp.sign(xp.asarray(centers)))
403407

404408
def test_kmeans_diff_convergence(self, xp):

scipy/cluster/vq.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@
7070
from scipy._lib._array_api import (
7171
_asarray, array_namespace, xp_size, xp_atleast_nd, xp_copy, xp_cov
7272
)
73-
from scipy._lib._util import check_random_state, rng_integers
73+
from scipy._lib._util import (check_random_state, rng_integers,
74+
_transition_to_rng)
7475
from scipy.spatial.distance import cdist
7576

7677
from . import _vq
@@ -327,8 +328,9 @@ def _kmeans(obs, guess, thresh=1e-5, xp=None):
327328
return code_book, prev_avg_dists[1]
328329

329330

331+
@_transition_to_rng("seed")
330332
def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
331-
*, seed=None):
333+
*, rng=None):
332334
"""
333335
Performs k-means on a set of observation vectors forming k clusters.
334336
@@ -374,16 +376,11 @@ def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
374376
Disabling may give a performance gain, but may result in problems
375377
(crashes, non-termination) if the inputs do contain infinities or NaNs.
376378
Default: True
377-
378-
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
379-
Seed for initializing the pseudo-random number generator.
380-
If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
381-
singleton is used.
382-
If `seed` is an int, a new ``RandomState`` instance is used,
383-
seeded with `seed`.
384-
If `seed` is already a ``Generator`` or ``RandomState`` instance then
385-
that instance is used.
386-
The default is None.
379+
rng : `numpy.random.Generator`, optional
380+
Pseudorandom number generator state. When `rng` is None, a new
381+
`numpy.random.Generator` is created using entropy from the
382+
operating system. Types other than `numpy.random.Generator` are
383+
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
387384
388385
Returns
389386
-------
@@ -484,7 +481,7 @@ def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
484481
if k < 1:
485482
raise ValueError("Asked for %d clusters." % k)
486483

487-
rng = check_random_state(seed)
484+
rng = check_random_state(rng)
488485

489486
# initialize best distance value to a large value
490487
best_dist = xp.inf
@@ -645,8 +642,9 @@ def _missing_raise():
645642
_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
646643

647644

645+
@_transition_to_rng("seed")
648646
def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
649-
missing='warn', check_finite=True, *, seed=None):
647+
missing='warn', check_finite=True, *, rng=None):
650648
"""
651649
Classify a set of observations into k clusters using the k-means algorithm.
652650
@@ -697,15 +695,11 @@ def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
697695
Disabling may give a performance gain, but may result in problems
698696
(crashes, non-termination) if the inputs do contain infinities or NaNs.
699697
Default: True
700-
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
701-
Seed for initializing the pseudo-random number generator.
702-
If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
703-
singleton is used.
704-
If `seed` is an int, a new ``RandomState`` instance is used,
705-
seeded with `seed`.
706-
If `seed` is already a ``Generator`` or ``RandomState`` instance then
707-
that instance is used.
708-
The default is None.
698+
rng : `numpy.random.Generator`, optional
699+
Pseudorandom number generator state. When `rng` is None, a new
700+
`numpy.random.Generator` is created using entropy from the
701+
operating system. Types other than `numpy.random.Generator` are
702+
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
709703
710704
Returns
711705
-------
@@ -814,7 +808,7 @@ def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
814808
except KeyError as e:
815809
raise ValueError(f"Unknown init method {minit!r}") from e
816810
else:
817-
rng = check_random_state(seed)
811+
rng = check_random_state(rng)
818812
code_book = init_meth(data, code_book, rng, xp)
819813

820814
data = np.asarray(data)

0 commit comments

Comments
 (0)