Skip to content

Commit 62e7590

Browse files
authored
MAINT: stats._resampling: transition to rng (scipy#21854)
1 parent c030db5 commit 62e7590

File tree

2 files changed

+99
-106
lines changed

2 files changed

+99
-106
lines changed

scipy/stats/_resampling.py

Lines changed: 48 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from dataclasses import dataclass
88
import inspect
99

10-
from scipy._lib._util import check_random_state, _rename_parameter, rng_integers
10+
from scipy._lib._util import (check_random_state, _rename_parameter, rng_integers,
11+
_transition_to_rng)
1112
from scipy._lib._array_api import array_namespace, is_numpy, xp_moveaxis_to_end
1213
from scipy.special import ndtr, ndtri, comb, factorial
1314

@@ -60,12 +61,12 @@ def _jackknife_resample(sample, batch=None):
6061
yield resamples
6162

6263

63-
def _bootstrap_resample(sample, n_resamples=None, random_state=None):
64+
def _bootstrap_resample(sample, n_resamples=None, rng=None):
6465
"""Bootstrap resample the sample."""
6566
n = sample.shape[-1]
6667

6768
# bootstrap - each row is a random resample of original observations
68-
i = rng_integers(random_state, 0, n, (n_resamples, n))
69+
i = rng_integers(rng, 0, n, (n_resamples, n))
6970

7071
resamples = sample[..., i]
7172
return resamples
@@ -159,7 +160,7 @@ def _bca_interval(data, statistic, axis, alpha, theta_hat_b, batch):
159160

160161
def _bootstrap_iv(data, statistic, vectorized, paired, axis, confidence_level,
161162
alternative, n_resamples, batch, method, bootstrap_result,
162-
random_state):
163+
rng):
163164
"""Input validation and standardization for `bootstrap`."""
164165

165166
if vectorized not in {True, False, None}:
@@ -261,11 +262,11 @@ def statistic(i, axis=-1, data=data_iv, unpaired_statistic=statistic):
261262
and n_resamples_int == 0):
262263
raise ValueError(message)
263264

264-
random_state = check_random_state(random_state)
265+
rng = check_random_state(rng)
265266

266267
return (data_iv, statistic, vectorized, paired, axis_int,
267268
confidence_level_float, alternative, n_resamples_int, batch_iv,
268-
method, bootstrap_result, random_state)
269+
method, bootstrap_result, rng)
269270

270271

271272
@dataclass
@@ -291,10 +292,11 @@ class BootstrapResult:
291292
standard_error: float | np.ndarray
292293

293294

295+
@_transition_to_rng('random_state')
294296
def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
295297
vectorized=None, paired=False, axis=0, confidence_level=0.95,
296298
alternative='two-sided', method='BCa', bootstrap_result=None,
297-
random_state=None):
299+
rng=None):
298300
r"""
299301
Compute a two-sided bootstrap confidence interval of a statistic.
300302
@@ -393,17 +395,11 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
393395
distribution. This can be used, for example, to change
394396
`confidence_level`, change `method`, or see the effect of performing
395397
additional resampling without repeating computations.
396-
random_state : {None, int, `numpy.random.Generator`,
397-
`numpy.random.RandomState`}, optional
398-
399-
Pseudorandom number generator state used to generate resamples.
400-
401-
If `random_state` is ``None`` (or `np.random`), the
402-
`numpy.random.RandomState` singleton is used.
403-
If `random_state` is an int, a new ``RandomState`` instance is used,
404-
seeded with `random_state`.
405-
If `random_state` is already a ``Generator`` or ``RandomState``
406-
instance then that instance is used.
398+
rng : `numpy.random.Generator`, optional
399+
Pseudorandom number generator state. When `rng` is None, a new
400+
`numpy.random.Generator` is created using entropy from the
401+
operating system. Types other than `numpy.random.Generator` are
402+
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
407403
408404
Returns
409405
-------
@@ -473,8 +469,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
473469
>>> import matplotlib.pyplot as plt
474470
>>> from scipy.stats import bootstrap
475471
>>> data = (data,) # samples must be in a sequence
476-
>>> res = bootstrap(data, np.std, confidence_level=0.9,
477-
... random_state=rng)
472+
>>> res = bootstrap(data, np.std, confidence_level=0.9, rng=rng)
478473
>>> fig, ax = plt.subplots()
479474
>>> ax.hist(res.bootstrap_distribution, bins=25)
480475
>>> ax.set_title('Bootstrap Distribution')
@@ -528,7 +523,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
528523
>>> for i in range(n_trials):
529524
... data = (dist.rvs(size=100, random_state=rng),)
530525
... res = bootstrap(data, np.std, confidence_level=0.9,
531-
... n_resamples=999, random_state=rng)
526+
... n_resamples=999, rng=rng)
532527
... ci = res.confidence_interval
533528
... if ci[0] < std_true < ci[1]:
534529
... ci_contains_true_std += 1
@@ -540,7 +535,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
540535
541536
>>> data = (dist.rvs(size=(n_trials, 100), random_state=rng),)
542537
>>> res = bootstrap(data, np.std, axis=-1, confidence_level=0.9,
543-
... n_resamples=999, random_state=rng)
538+
... n_resamples=999, rng=rng)
544539
>>> ci_l, ci_u = res.confidence_interval
545540
546541
Here, `ci_l` and `ci_u` contain the confidence interval for each of the
@@ -574,7 +569,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
574569
>>> sample1 = norm.rvs(scale=1, size=100, random_state=rng)
575570
>>> sample2 = norm.rvs(scale=2, size=100, random_state=rng)
576571
>>> data = (sample1, sample2)
577-
>>> res = bootstrap(data, my_statistic, method='basic', random_state=rng)
572+
>>> res = bootstrap(data, my_statistic, method='basic', rng=rng)
578573
>>> print(my_statistic(sample1, sample2))
579574
0.16661030792089523
580575
>>> print(res.confidence_interval)
@@ -603,7 +598,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
603598
604599
We call `bootstrap` using ``paired=True``.
605600
606-
>>> res = bootstrap((x, y), my_statistic, paired=True, random_state=rng)
601+
>>> res = bootstrap((x, y), my_statistic, paired=True, rng=rng)
607602
>>> print(res.confidence_interval)
608603
ConfidenceInterval(low=0.9941504301315878, high=0.996377412215445)
609604
@@ -613,15 +608,15 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
613608
>>> len(res.bootstrap_distribution)
614609
9999
615610
>>> res = bootstrap((x, y), my_statistic, paired=True,
616-
... n_resamples=1000, random_state=rng,
611+
... n_resamples=1000, rng=rng,
617612
... bootstrap_result=res)
618613
>>> len(res.bootstrap_distribution)
619614
10999
620615
621616
or to change the confidence interval options:
622617
623618
>>> res2 = bootstrap((x, y), my_statistic, paired=True,
624-
... n_resamples=0, random_state=rng, bootstrap_result=res,
619+
... n_resamples=0, rng=rng, bootstrap_result=res,
625620
... method='percentile', confidence_level=0.9)
626621
>>> np.testing.assert_equal(res2.bootstrap_distribution,
627622
... res.bootstrap_distribution)
@@ -634,10 +629,10 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
634629
# Input validation
635630
args = _bootstrap_iv(data, statistic, vectorized, paired, axis,
636631
confidence_level, alternative, n_resamples, batch,
637-
method, bootstrap_result, random_state)
632+
method, bootstrap_result, rng)
638633
(data, statistic, vectorized, paired, axis, confidence_level,
639634
alternative, n_resamples, batch, method, bootstrap_result,
640-
random_state) = args
635+
rng) = args
641636

642637
theta_hat_b = ([] if bootstrap_result is None
643638
else [bootstrap_result.bootstrap_distribution])
@@ -650,7 +645,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
650645
resampled_data = []
651646
for sample in data:
652647
resample = _bootstrap_resample(sample, n_resamples=batch_actual,
653-
random_state=random_state)
648+
rng=rng)
654649
resampled_data.append(resample)
655650

656651
# Compute bootstrap distribution of statistic
@@ -1401,34 +1396,34 @@ def _batch_generator(iterable, batch):
14011396

14021397

14031398
def _pairings_permutations_gen(n_permutations, n_samples, n_obs_sample, batch,
1404-
random_state):
1399+
rng):
14051400
# Returns a generator that yields arrays of size
14061401
# `(batch, n_samples, n_obs_sample)`.
14071402
# Each row is an independent permutation of indices 0 to `n_obs_sample`.
14081403
batch = min(batch, n_permutations)
14091404

1410-
if hasattr(random_state, 'permuted'):
1405+
if hasattr(rng, 'permuted'):
14111406
def batched_perm_generator():
14121407
indices = np.arange(n_obs_sample)
14131408
indices = np.tile(indices, (batch, n_samples, 1))
14141409
for k in range(0, n_permutations, batch):
14151410
batch_actual = min(batch, n_permutations-k)
14161411
# Don't permute in place, otherwise results depend on `batch`
1417-
permuted_indices = random_state.permuted(indices, axis=-1)
1412+
permuted_indices = rng.permuted(indices, axis=-1)
14181413
yield permuted_indices[:batch_actual]
14191414
else: # RandomState and early Generators don't have `permuted`
14201415
def batched_perm_generator():
14211416
for k in range(0, n_permutations, batch):
14221417
batch_actual = min(batch, n_permutations-k)
14231418
size = (batch_actual, n_samples, n_obs_sample)
1424-
x = random_state.random(size=size)
1419+
x = rng.random(size=size)
14251420
yield np.argsort(x, axis=-1)[:batch_actual]
14261421

14271422
return batched_perm_generator()
14281423

14291424

14301425
def _calculate_null_both(data, statistic, n_permutations, batch,
1431-
random_state=None):
1426+
rng=None):
14321427
"""
14331428
Calculate null distribution for independent sample tests.
14341429
"""
@@ -1455,7 +1450,7 @@ def _calculate_null_both(data, statistic, n_permutations, batch,
14551450
# can permute axis-slices independently. If this feature is
14561451
# added in the future, batches of the desired size should be
14571452
# generated in a single call.
1458-
perm_generator = (random_state.permutation(n_obs)
1453+
perm_generator = (rng.permutation(n_obs)
14591454
for i in range(n_permutations))
14601455

14611456
batch = batch or int(n_permutations)
@@ -1488,7 +1483,7 @@ def _calculate_null_both(data, statistic, n_permutations, batch,
14881483

14891484

14901485
def _calculate_null_pairings(data, statistic, n_permutations, batch,
1491-
random_state=None):
1486+
rng=None):
14921487
"""
14931488
Calculate null distribution for association tests.
14941489
"""
@@ -1514,7 +1509,7 @@ def _calculate_null_pairings(data, statistic, n_permutations, batch,
15141509
# Separate random permutations of indices for each sample.
15151510
# Again, it would be nice if RandomState/Generator.permutation
15161511
# could permute each axis-slice separately.
1517-
args = n_permutations, n_samples, n_obs_sample, batch, random_state
1512+
args = n_permutations, n_samples, n_obs_sample, batch, rng
15181513
batched_perm_generator = _pairings_permutations_gen(*args)
15191514

15201515
null_distribution = []
@@ -1545,7 +1540,7 @@ def _calculate_null_pairings(data, statistic, n_permutations, batch,
15451540

15461541

15471542
def _calculate_null_samples(data, statistic, n_permutations, batch,
1548-
random_state=None):
1543+
rng=None):
15491544
"""
15501545
Calculate null distribution for paired-sample tests.
15511546
"""
@@ -1572,11 +1567,11 @@ def statistic_wrapped(*data, axis):
15721567
return statistic(*data, axis=axis)
15731568

15741569
return _calculate_null_pairings(data, statistic_wrapped, n_permutations,
1575-
batch, random_state)
1570+
batch, rng)
15761571

15771572

15781573
def _permutation_test_iv(data, statistic, permutation_type, vectorized,
1579-
n_resamples, batch, alternative, axis, random_state):
1574+
n_resamples, batch, alternative, axis, rng):
15801575
"""Input validation for `permutation_test`."""
15811576

15821577
axis_int = int(axis)
@@ -1631,15 +1626,16 @@ def _permutation_test_iv(data, statistic, permutation_type, vectorized,
16311626
if alternative not in alternatives:
16321627
raise ValueError(f"`alternative` must be in {alternatives}")
16331628

1634-
random_state = check_random_state(random_state)
1629+
rng = check_random_state(rng)
16351630

16361631
return (data_iv, statistic, permutation_type, vectorized, n_resamples_int,
1637-
batch_iv, alternative, axis_int, random_state)
1632+
batch_iv, alternative, axis_int, rng)
16381633

16391634

1635+
@_transition_to_rng('random_state')
16401636
def permutation_test(data, statistic, *, permutation_type='independent',
16411637
vectorized=None, n_resamples=9999, batch=None,
1642-
alternative="two-sided", axis=0, random_state=None):
1638+
alternative="two-sided", axis=0, rng=None):
16431639
r"""
16441640
Performs a permutation test of a given statistic on provided data.
16451641
@@ -1738,17 +1734,11 @@ def permutation_test(data, statistic, *, permutation_type='independent',
17381734
statistic. If samples have a different number of dimensions,
17391735
singleton dimensions are prepended to samples with fewer dimensions
17401736
before `axis` is considered.
1741-
random_state : {None, int, `numpy.random.Generator`,
1742-
`numpy.random.RandomState`}, optional
1743-
1744-
Pseudorandom number generator state used to generate permutations.
1745-
1746-
If `random_state` is ``None`` (default), the
1747-
`numpy.random.RandomState` singleton is used.
1748-
If `random_state` is an int, a new ``RandomState`` instance is used,
1749-
seeded with `random_state`.
1750-
If `random_state` is already a ``Generator`` or ``RandomState``
1751-
instance then that instance is used.
1737+
rng : `numpy.random.Generator`, optional
1738+
Pseudorandom number generator state. When `rng` is None, a new
1739+
`numpy.random.Generator` is created using entropy from the
1740+
operating system. Types other than `numpy.random.Generator` are
1741+
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
17521742
17531743
Returns
17541744
-------
@@ -1980,7 +1970,7 @@ def permutation_test(data, statistic, *, permutation_type='independent',
19801970
>>> y = norm.rvs(size=120, loc=0.2, random_state=rng)
19811971
>>> res = permutation_test((x, y), statistic, n_resamples=9999,
19821972
... vectorized=True, alternative='less',
1983-
... random_state=rng)
1973+
... rng=rng)
19841974
>>> print(res.statistic)
19851975
-0.4230459671240913
19861976
>>> print(res.pvalue)
@@ -2066,17 +2056,17 @@ def permutation_test(data, statistic, *, permutation_type='independent',
20662056
"""
20672057
args = _permutation_test_iv(data, statistic, permutation_type, vectorized,
20682058
n_resamples, batch, alternative, axis,
2069-
random_state)
2059+
rng)
20702060
(data, statistic, permutation_type, vectorized, n_resamples, batch,
2071-
alternative, axis, random_state) = args
2061+
alternative, axis, rng) = args
20722062

20732063
observed = statistic(*data, axis=-1)
20742064

20752065
null_calculators = {"pairings": _calculate_null_pairings,
20762066
"samples": _calculate_null_samples,
20772067
"independent": _calculate_null_both}
20782068
null_calculator_args = (data, statistic, n_resamples,
2079-
batch, random_state)
2069+
batch, rng)
20802070
calculate_null = null_calculators[permutation_type]
20812071
null_distribution, n_resamples, exact_test = (
20822072
calculate_null(*null_calculator_args))

0 commit comments

Comments
 (0)