7
7
from dataclasses import dataclass
8
8
import inspect
9
9
10
- from scipy ._lib ._util import check_random_state , _rename_parameter , rng_integers
10
+ from scipy ._lib ._util import (check_random_state , _rename_parameter , rng_integers ,
11
+ _transition_to_rng )
11
12
from scipy ._lib ._array_api import array_namespace , is_numpy , xp_moveaxis_to_end
12
13
from scipy .special import ndtr , ndtri , comb , factorial
13
14
@@ -60,12 +61,12 @@ def _jackknife_resample(sample, batch=None):
60
61
yield resamples
61
62
62
63
63
- def _bootstrap_resample (sample , n_resamples = None , random_state = None ):
64
+ def _bootstrap_resample (sample , n_resamples = None , rng = None ):
64
65
"""Bootstrap resample the sample."""
65
66
n = sample .shape [- 1 ]
66
67
67
68
# bootstrap - each row is a random resample of original observations
68
- i = rng_integers (random_state , 0 , n , (n_resamples , n ))
69
+ i = rng_integers (rng , 0 , n , (n_resamples , n ))
69
70
70
71
resamples = sample [..., i ]
71
72
return resamples
@@ -159,7 +160,7 @@ def _bca_interval(data, statistic, axis, alpha, theta_hat_b, batch):
159
160
160
161
def _bootstrap_iv (data , statistic , vectorized , paired , axis , confidence_level ,
161
162
alternative , n_resamples , batch , method , bootstrap_result ,
162
- random_state ):
163
+ rng ):
163
164
"""Input validation and standardization for `bootstrap`."""
164
165
165
166
if vectorized not in {True , False , None }:
@@ -261,11 +262,11 @@ def statistic(i, axis=-1, data=data_iv, unpaired_statistic=statistic):
261
262
and n_resamples_int == 0 ):
262
263
raise ValueError (message )
263
264
264
- random_state = check_random_state (random_state )
265
+ rng = check_random_state (rng )
265
266
266
267
return (data_iv , statistic , vectorized , paired , axis_int ,
267
268
confidence_level_float , alternative , n_resamples_int , batch_iv ,
268
- method , bootstrap_result , random_state )
269
+ method , bootstrap_result , rng )
269
270
270
271
271
272
@dataclass
@@ -291,10 +292,11 @@ class BootstrapResult:
291
292
standard_error : float | np .ndarray
292
293
293
294
295
+ @_transition_to_rng ('random_state' )
294
296
def bootstrap (data , statistic , * , n_resamples = 9999 , batch = None ,
295
297
vectorized = None , paired = False , axis = 0 , confidence_level = 0.95 ,
296
298
alternative = 'two-sided' , method = 'BCa' , bootstrap_result = None ,
297
- random_state = None ):
299
+ rng = None ):
298
300
r"""
299
301
Compute a two-sided bootstrap confidence interval of a statistic.
300
302
@@ -393,17 +395,11 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
393
395
distribution. This can be used, for example, to change
394
396
`confidence_level`, change `method`, or see the effect of performing
395
397
additional resampling without repeating computations.
396
- random_state : {None, int, `numpy.random.Generator`,
397
- `numpy.random.RandomState`}, optional
398
-
399
- Pseudorandom number generator state used to generate resamples.
400
-
401
- If `random_state` is ``None`` (or `np.random`), the
402
- `numpy.random.RandomState` singleton is used.
403
- If `random_state` is an int, a new ``RandomState`` instance is used,
404
- seeded with `random_state`.
405
- If `random_state` is already a ``Generator`` or ``RandomState``
406
- instance then that instance is used.
398
+ rng : `numpy.random.Generator`, optional
399
+ Pseudorandom number generator state. When `rng` is None, a new
400
+ `numpy.random.Generator` is created using entropy from the
401
+ operating system. Types other than `numpy.random.Generator` are
402
+ passed to `numpy.random.default_rng` to instantiate a ``Generator``.
407
403
408
404
Returns
409
405
-------
@@ -473,8 +469,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
473
469
>>> import matplotlib.pyplot as plt
474
470
>>> from scipy.stats import bootstrap
475
471
>>> data = (data,) # samples must be in a sequence
476
- >>> res = bootstrap(data, np.std, confidence_level=0.9,
477
- ... random_state=rng)
472
+ >>> res = bootstrap(data, np.std, confidence_level=0.9, rng=rng)
478
473
>>> fig, ax = plt.subplots()
479
474
>>> ax.hist(res.bootstrap_distribution, bins=25)
480
475
>>> ax.set_title('Bootstrap Distribution')
@@ -528,7 +523,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
528
523
>>> for i in range(n_trials):
529
524
... data = (dist.rvs(size=100, random_state=rng),)
530
525
... res = bootstrap(data, np.std, confidence_level=0.9,
531
- ... n_resamples=999, random_state =rng)
526
+ ... n_resamples=999, rng =rng)
532
527
... ci = res.confidence_interval
533
528
... if ci[0] < std_true < ci[1]:
534
529
... ci_contains_true_std += 1
@@ -540,7 +535,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
540
535
541
536
>>> data = (dist.rvs(size=(n_trials, 100), random_state=rng),)
542
537
>>> res = bootstrap(data, np.std, axis=-1, confidence_level=0.9,
543
- ... n_resamples=999, random_state =rng)
538
+ ... n_resamples=999, rng =rng)
544
539
>>> ci_l, ci_u = res.confidence_interval
545
540
546
541
Here, `ci_l` and `ci_u` contain the confidence interval for each of the
@@ -574,7 +569,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
574
569
>>> sample1 = norm.rvs(scale=1, size=100, random_state=rng)
575
570
>>> sample2 = norm.rvs(scale=2, size=100, random_state=rng)
576
571
>>> data = (sample1, sample2)
577
- >>> res = bootstrap(data, my_statistic, method='basic', random_state =rng)
572
+ >>> res = bootstrap(data, my_statistic, method='basic', rng =rng)
578
573
>>> print(my_statistic(sample1, sample2))
579
574
0.16661030792089523
580
575
>>> print(res.confidence_interval)
@@ -603,7 +598,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
603
598
604
599
We call `bootstrap` using ``paired=True``.
605
600
606
- >>> res = bootstrap((x, y), my_statistic, paired=True, random_state =rng)
601
+ >>> res = bootstrap((x, y), my_statistic, paired=True, rng =rng)
607
602
>>> print(res.confidence_interval)
608
603
ConfidenceInterval(low=0.9941504301315878, high=0.996377412215445)
609
604
@@ -613,15 +608,15 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
613
608
>>> len(res.bootstrap_distribution)
614
609
9999
615
610
>>> res = bootstrap((x, y), my_statistic, paired=True,
616
- ... n_resamples=1000, random_state =rng,
611
+ ... n_resamples=1000, rng =rng,
617
612
... bootstrap_result=res)
618
613
>>> len(res.bootstrap_distribution)
619
614
10999
620
615
621
616
or to change the confidence interval options:
622
617
623
618
>>> res2 = bootstrap((x, y), my_statistic, paired=True,
624
- ... n_resamples=0, random_state =rng, bootstrap_result=res,
619
+ ... n_resamples=0, rng =rng, bootstrap_result=res,
625
620
... method='percentile', confidence_level=0.9)
626
621
>>> np.testing.assert_equal(res2.bootstrap_distribution,
627
622
... res.bootstrap_distribution)
@@ -634,10 +629,10 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
634
629
# Input validation
635
630
args = _bootstrap_iv (data , statistic , vectorized , paired , axis ,
636
631
confidence_level , alternative , n_resamples , batch ,
637
- method , bootstrap_result , random_state )
632
+ method , bootstrap_result , rng )
638
633
(data , statistic , vectorized , paired , axis , confidence_level ,
639
634
alternative , n_resamples , batch , method , bootstrap_result ,
640
- random_state ) = args
635
+ rng ) = args
641
636
642
637
theta_hat_b = ([] if bootstrap_result is None
643
638
else [bootstrap_result .bootstrap_distribution ])
@@ -650,7 +645,7 @@ def bootstrap(data, statistic, *, n_resamples=9999, batch=None,
650
645
resampled_data = []
651
646
for sample in data :
652
647
resample = _bootstrap_resample (sample , n_resamples = batch_actual ,
653
- random_state = random_state )
648
+ rng = rng )
654
649
resampled_data .append (resample )
655
650
656
651
# Compute bootstrap distribution of statistic
@@ -1401,34 +1396,34 @@ def _batch_generator(iterable, batch):
1401
1396
1402
1397
1403
1398
def _pairings_permutations_gen (n_permutations , n_samples , n_obs_sample , batch ,
1404
- random_state ):
1399
+ rng ):
1405
1400
# Returns a generator that yields arrays of size
1406
1401
# `(batch, n_samples, n_obs_sample)`.
1407
1402
# Each row is an independent permutation of indices 0 to `n_obs_sample`.
1408
1403
batch = min (batch , n_permutations )
1409
1404
1410
- if hasattr (random_state , 'permuted' ):
1405
+ if hasattr (rng , 'permuted' ):
1411
1406
def batched_perm_generator ():
1412
1407
indices = np .arange (n_obs_sample )
1413
1408
indices = np .tile (indices , (batch , n_samples , 1 ))
1414
1409
for k in range (0 , n_permutations , batch ):
1415
1410
batch_actual = min (batch , n_permutations - k )
1416
1411
# Don't permute in place, otherwise results depend on `batch`
1417
- permuted_indices = random_state .permuted (indices , axis = - 1 )
1412
+ permuted_indices = rng .permuted (indices , axis = - 1 )
1418
1413
yield permuted_indices [:batch_actual ]
1419
1414
else : # RandomState and early Generators don't have `permuted`
1420
1415
def batched_perm_generator ():
1421
1416
for k in range (0 , n_permutations , batch ):
1422
1417
batch_actual = min (batch , n_permutations - k )
1423
1418
size = (batch_actual , n_samples , n_obs_sample )
1424
- x = random_state .random (size = size )
1419
+ x = rng .random (size = size )
1425
1420
yield np .argsort (x , axis = - 1 )[:batch_actual ]
1426
1421
1427
1422
return batched_perm_generator ()
1428
1423
1429
1424
1430
1425
def _calculate_null_both (data , statistic , n_permutations , batch ,
1431
- random_state = None ):
1426
+ rng = None ):
1432
1427
"""
1433
1428
Calculate null distribution for independent sample tests.
1434
1429
"""
@@ -1455,7 +1450,7 @@ def _calculate_null_both(data, statistic, n_permutations, batch,
1455
1450
# can permute axis-slices independently. If this feature is
1456
1451
# added in the future, batches of the desired size should be
1457
1452
# generated in a single call.
1458
- perm_generator = (random_state .permutation (n_obs )
1453
+ perm_generator = (rng .permutation (n_obs )
1459
1454
for i in range (n_permutations ))
1460
1455
1461
1456
batch = batch or int (n_permutations )
@@ -1488,7 +1483,7 @@ def _calculate_null_both(data, statistic, n_permutations, batch,
1488
1483
1489
1484
1490
1485
def _calculate_null_pairings (data , statistic , n_permutations , batch ,
1491
- random_state = None ):
1486
+ rng = None ):
1492
1487
"""
1493
1488
Calculate null distribution for association tests.
1494
1489
"""
@@ -1514,7 +1509,7 @@ def _calculate_null_pairings(data, statistic, n_permutations, batch,
1514
1509
# Separate random permutations of indices for each sample.
1515
1510
# Again, it would be nice if RandomState/Generator.permutation
1516
1511
# could permute each axis-slice separately.
1517
- args = n_permutations , n_samples , n_obs_sample , batch , random_state
1512
+ args = n_permutations , n_samples , n_obs_sample , batch , rng
1518
1513
batched_perm_generator = _pairings_permutations_gen (* args )
1519
1514
1520
1515
null_distribution = []
@@ -1545,7 +1540,7 @@ def _calculate_null_pairings(data, statistic, n_permutations, batch,
1545
1540
1546
1541
1547
1542
def _calculate_null_samples (data , statistic , n_permutations , batch ,
1548
- random_state = None ):
1543
+ rng = None ):
1549
1544
"""
1550
1545
Calculate null distribution for paired-sample tests.
1551
1546
"""
@@ -1572,11 +1567,11 @@ def statistic_wrapped(*data, axis):
1572
1567
return statistic (* data , axis = axis )
1573
1568
1574
1569
return _calculate_null_pairings (data , statistic_wrapped , n_permutations ,
1575
- batch , random_state )
1570
+ batch , rng )
1576
1571
1577
1572
1578
1573
def _permutation_test_iv (data , statistic , permutation_type , vectorized ,
1579
- n_resamples , batch , alternative , axis , random_state ):
1574
+ n_resamples , batch , alternative , axis , rng ):
1580
1575
"""Input validation for `permutation_test`."""
1581
1576
1582
1577
axis_int = int (axis )
@@ -1631,15 +1626,16 @@ def _permutation_test_iv(data, statistic, permutation_type, vectorized,
1631
1626
if alternative not in alternatives :
1632
1627
raise ValueError (f"`alternative` must be in { alternatives } " )
1633
1628
1634
- random_state = check_random_state (random_state )
1629
+ rng = check_random_state (rng )
1635
1630
1636
1631
return (data_iv , statistic , permutation_type , vectorized , n_resamples_int ,
1637
- batch_iv , alternative , axis_int , random_state )
1632
+ batch_iv , alternative , axis_int , rng )
1638
1633
1639
1634
1635
+ @_transition_to_rng ('random_state' )
1640
1636
def permutation_test (data , statistic , * , permutation_type = 'independent' ,
1641
1637
vectorized = None , n_resamples = 9999 , batch = None ,
1642
- alternative = "two-sided" , axis = 0 , random_state = None ):
1638
+ alternative = "two-sided" , axis = 0 , rng = None ):
1643
1639
r"""
1644
1640
Performs a permutation test of a given statistic on provided data.
1645
1641
@@ -1738,17 +1734,11 @@ def permutation_test(data, statistic, *, permutation_type='independent',
1738
1734
statistic. If samples have a different number of dimensions,
1739
1735
singleton dimensions are prepended to samples with fewer dimensions
1740
1736
before `axis` is considered.
1741
- random_state : {None, int, `numpy.random.Generator`,
1742
- `numpy.random.RandomState`}, optional
1743
-
1744
- Pseudorandom number generator state used to generate permutations.
1745
-
1746
- If `random_state` is ``None`` (default), the
1747
- `numpy.random.RandomState` singleton is used.
1748
- If `random_state` is an int, a new ``RandomState`` instance is used,
1749
- seeded with `random_state`.
1750
- If `random_state` is already a ``Generator`` or ``RandomState``
1751
- instance then that instance is used.
1737
+ rng : `numpy.random.Generator`, optional
1738
+ Pseudorandom number generator state. When `rng` is None, a new
1739
+ `numpy.random.Generator` is created using entropy from the
1740
+ operating system. Types other than `numpy.random.Generator` are
1741
+ passed to `numpy.random.default_rng` to instantiate a ``Generator``.
1752
1742
1753
1743
Returns
1754
1744
-------
@@ -1980,7 +1970,7 @@ def permutation_test(data, statistic, *, permutation_type='independent',
1980
1970
>>> y = norm.rvs(size=120, loc=0.2, random_state=rng)
1981
1971
>>> res = permutation_test((x, y), statistic, n_resamples=9999,
1982
1972
... vectorized=True, alternative='less',
1983
- ... random_state =rng)
1973
+ ... rng =rng)
1984
1974
>>> print(res.statistic)
1985
1975
-0.4230459671240913
1986
1976
>>> print(res.pvalue)
@@ -2066,17 +2056,17 @@ def permutation_test(data, statistic, *, permutation_type='independent',
2066
2056
"""
2067
2057
args = _permutation_test_iv (data , statistic , permutation_type , vectorized ,
2068
2058
n_resamples , batch , alternative , axis ,
2069
- random_state )
2059
+ rng )
2070
2060
(data , statistic , permutation_type , vectorized , n_resamples , batch ,
2071
- alternative , axis , random_state ) = args
2061
+ alternative , axis , rng ) = args
2072
2062
2073
2063
observed = statistic (* data , axis = - 1 )
2074
2064
2075
2065
null_calculators = {"pairings" : _calculate_null_pairings ,
2076
2066
"samples" : _calculate_null_samples ,
2077
2067
"independent" : _calculate_null_both }
2078
2068
null_calculator_args = (data , statistic , n_resamples ,
2079
- batch , random_state )
2069
+ batch , rng )
2080
2070
calculate_null = null_calculators [permutation_type ]
2081
2071
null_distribution , n_resamples , exact_test = (
2082
2072
calculate_null (* null_calculator_args ))
0 commit comments