Skip to content

Commit 646e5f8

Browse files
authored
MAINT: spatial.transform.Rotation.random: transition to RNG (SPEC 7) (scipy#21914)
1 parent d26dffc commit 646e5f8

File tree

2 files changed

+56
-58
lines changed

2 files changed

+56
-58
lines changed

scipy/spatial/transform/_rotation.pyx

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import warnings
55
import numpy as np
6-
from scipy._lib._util import check_random_state
6+
from scipy._lib._util import check_random_state, _transition_to_rng
77
from ._rotation_groups import create_group
88

99
cimport numpy as np
@@ -3185,23 +3185,20 @@ cdef class Rotation:
31853185

31863186
@cython.embedsignature(True)
31873187
@classmethod
3188-
def random(cls, num=None, random_state=None):
3188+
@_transition_to_rng('random_state', position_num=2)
3189+
def random(cls, num=None, rng=None):
31893190
"""Generate uniformly distributed rotations.
31903191
31913192
Parameters
31923193
----------
31933194
num : int or None, optional
31943195
Number of random rotations to generate. If None (default), then a
31953196
single rotation is generated.
3196-
random_state : {None, int, `numpy.random.Generator`,
3197-
`numpy.random.RandomState`}, optional
3198-
3199-
If `seed` is None (or `np.random`), the `numpy.random.RandomState`
3200-
singleton is used.
3201-
If `seed` is an int, a new ``RandomState`` instance is used,
3202-
seeded with `seed`.
3203-
If `seed` is already a ``Generator`` or ``RandomState`` instance
3204-
then that instance is used.
3197+
rng : `numpy.random.Generator`, optional
3198+
Pseudorandom number generator state. When `rng` is None, a new
3199+
`numpy.random.Generator` is created using entropy from the
3200+
operating system. Types other than `numpy.random.Generator` are
3201+
passed to `numpy.random.default_rng` to instantiate a `Generator`.
32053202
32063203
Returns
32073204
-------
@@ -3238,12 +3235,12 @@ cdef class Rotation:
32383235
scipy.stats.special_ortho_group
32393236
32403237
"""
3241-
random_state = check_random_state(random_state)
3238+
rng = check_random_state(rng)
32423239

32433240
if num is None:
3244-
sample = random_state.normal(size=4)
3241+
sample = rng.normal(size=4)
32453242
else:
3246-
sample = random_state.normal(size=(num, 4))
3243+
sample = rng.normal(size=(num, 4))
32473244

32483245
return cls(sample)
32493246

scipy/spatial/transform/tests/test_rotation.py

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ def test_as_euler_degenerate_compare_algorithms(seq_tuple, intrinsic):
882882
def test_inv():
883883
rnd = np.random.RandomState(0)
884884
n = 10
885+
# preserve use of old random_state during SPEC 7 transition
885886
p = Rotation.random(num=n, random_state=rnd)
886887
q = p.inv()
887888

@@ -898,8 +899,8 @@ def test_inv():
898899

899900

900901
def test_inv_single_rotation():
901-
rnd = np.random.RandomState(0)
902-
p = Rotation.random(random_state=rnd)
902+
rng = np.random.default_rng(146972845698875399755764481408308808739)
903+
p = Rotation.random(rng=rng)
903904
q = p.inv()
904905

905906
p_mat = p.as_matrix()
@@ -912,7 +913,7 @@ def test_inv_single_rotation():
912913
assert_array_almost_equal(res1, eye)
913914
assert_array_almost_equal(res2, eye)
914915

915-
x = Rotation.random(num=1, random_state=rnd)
916+
x = Rotation.random(num=1, rng=rng)
916917
y = x.inv()
917918

918919
x_matrix = x.as_matrix()
@@ -940,7 +941,7 @@ def test_single_identity_magnitude():
940941

941942
def test_identity_invariance():
942943
n = 10
943-
p = Rotation.random(n, random_state=0)
944+
p = Rotation.random(n, rng=0)
944945

945946
result = p * Rotation.identity(n)
946947
assert_array_almost_equal(p.as_quat(), result.as_quat())
@@ -951,7 +952,7 @@ def test_identity_invariance():
951952

952953
def test_single_identity_invariance():
953954
n = 10
954-
p = Rotation.random(n, random_state=0)
955+
p = Rotation.random(n, rng=0)
955956

956957
result = p * Rotation.identity()
957958
assert_array_almost_equal(p.as_quat(), result.as_quat())
@@ -980,9 +981,9 @@ def test_magnitude_single_rotation():
980981

981982

982983
def test_approx_equal():
983-
rng = np.random.RandomState(0)
984-
p = Rotation.random(10, random_state=rng)
985-
q = Rotation.random(10, random_state=rng)
984+
rng = np.random.default_rng(146972845698875399755764481408308808739)
985+
p = Rotation.random(10, rng=rng)
986+
q = Rotation.random(10, rng=rng)
986987
r = p * q.inv()
987988
r_mag = r.magnitude()
988989
atol = np.median(r_mag) # ensure we get mix of Trues and Falses
@@ -1046,10 +1047,10 @@ def test_reduction_none_indices():
10461047

10471048

10481049
def test_reduction_scalar_calculation():
1049-
rng = np.random.RandomState(0)
1050-
l = Rotation.random(5, random_state=rng)
1051-
r = Rotation.random(10, random_state=rng)
1052-
p = Rotation.random(7, random_state=rng)
1050+
rng = np.random.default_rng(146972845698875399755764481408308808739)
1051+
l = Rotation.random(5, rng=rng)
1052+
r = Rotation.random(10, rng=rng)
1053+
p = Rotation.random(7, rng=rng)
10531054
reduced, left_best, right_best = p.reduce(l, r, return_indices=True)
10541055

10551056
# Loop implementation of the vectorized calculation in Rotation.reduce
@@ -1201,23 +1202,23 @@ def test_setitem_single():
12011202

12021203

12031204
def test_setitem_slice():
1204-
rng = np.random.RandomState(seed=0)
1205-
r1 = Rotation.random(10, random_state=rng)
1206-
r2 = Rotation.random(5, random_state=rng)
1205+
rng = np.random.default_rng(146972845698875399755764481408308808739)
1206+
r1 = Rotation.random(10, rng=rng)
1207+
r2 = Rotation.random(5, rng=rng)
12071208
r1[1:6] = r2
12081209
assert_equal(r1[1:6].as_quat(), r2.as_quat())
12091210

12101211

12111212
def test_setitem_integer():
1212-
rng = np.random.RandomState(seed=0)
1213-
r1 = Rotation.random(10, random_state=rng)
1214-
r2 = Rotation.random(random_state=rng)
1213+
rng = np.random.default_rng(146972845698875399755764481408308808739)
1214+
r1 = Rotation.random(10, rng=rng)
1215+
r2 = Rotation.random(rng=rng)
12151216
r1[1] = r2
12161217
assert_equal(r1[1].as_quat(), r2.as_quat())
12171218

12181219

12191220
def test_setitem_wrong_type():
1220-
r = Rotation.random(10, random_state=0)
1221+
r = Rotation.random(10, rng=0)
12211222
with pytest.raises(TypeError, match='Rotation object'):
12221223
r[0] = 1
12231224

@@ -1241,12 +1242,12 @@ def test_n_rotations():
12411242

12421243

12431244
def test_random_rotation_shape():
1244-
rnd = np.random.RandomState(0)
1245-
assert_equal(Rotation.random(random_state=rnd).as_quat().shape, (4,))
1246-
assert_equal(Rotation.random(None, random_state=rnd).as_quat().shape, (4,))
1245+
rng = np.random.default_rng(146972845698875399755764481408308808739)
1246+
assert_equal(Rotation.random(rng=rng).as_quat().shape, (4,))
1247+
assert_equal(Rotation.random(None, rng=rng).as_quat().shape, (4,))
12471248

1248-
assert_equal(Rotation.random(1, random_state=rnd).as_quat().shape, (1, 4))
1249-
assert_equal(Rotation.random(5, random_state=rnd).as_quat().shape, (5, 4))
1249+
assert_equal(Rotation.random(1, rng=rng).as_quat().shape, (1, 4))
1250+
assert_equal(Rotation.random(5, rng=rng).as_quat().shape, (5, 4))
12501251

12511252

12521253
def test_align_vectors_no_rotation():
@@ -1259,9 +1260,9 @@ def test_align_vectors_no_rotation():
12591260

12601261

12611262
def test_align_vectors_no_noise():
1262-
rnd = np.random.RandomState(0)
1263-
c = Rotation.random(random_state=rnd)
1264-
b = rnd.normal(size=(5, 3))
1263+
rng = np.random.default_rng(14697284569885399755764481408308808739)
1264+
c = Rotation.random(rng=rng)
1265+
b = rng.normal(size=(5, 3))
12651266
a = c.apply(b)
12661267

12671268
est, rssd = Rotation.align_vectors(a, b)
@@ -1296,8 +1297,8 @@ def test_align_vectors_rssd_sensitivity():
12961297

12971298
def test_align_vectors_scaled_weights():
12981299
n = 10
1299-
a = Rotation.random(n, random_state=0).apply([1, 0, 0])
1300-
b = Rotation.random(n, random_state=1).apply([1, 0, 0])
1300+
a = Rotation.random(n, rng=0).apply([1, 0, 0])
1301+
b = Rotation.random(n, rng=1).apply([1, 0, 0])
13011302
scale = 2
13021303

13031304
est1, rssd1, cov1 = Rotation.align_vectors(a, b, np.ones(n), True)
@@ -1309,17 +1310,17 @@ def test_align_vectors_scaled_weights():
13091310

13101311

13111312
def test_align_vectors_noise():
1312-
rnd = np.random.RandomState(0)
1313+
rng = np.random.default_rng(146972845698875399755764481408308808739)
13131314
n_vectors = 100
1314-
rot = Rotation.random(random_state=rnd)
1315-
vectors = rnd.normal(size=(n_vectors, 3))
1315+
rot = Rotation.random(rng=rng)
1316+
vectors = rng.normal(size=(n_vectors, 3))
13161317
result = rot.apply(vectors)
13171318

13181319
# The paper adds noise as independently distributed angular errors
13191320
sigma = np.deg2rad(1)
13201321
tolerance = 1.5 * sigma
13211322
noise = Rotation.from_rotvec(
1322-
rnd.normal(
1323+
rng.normal(
13231324
size=(n_vectors, 3),
13241325
scale=sigma
13251326
)
@@ -1433,7 +1434,7 @@ def test_align_vectors_near_inf():
14331434
n = 100
14341435
mats = []
14351436
for i in range(6):
1436-
mats.append(Rotation.random(n, random_state=10 + i).as_matrix())
1437+
mats.append(Rotation.random(n, rng=10 + i).as_matrix())
14371438

14381439
for i in range(n):
14391440
# Get random pairs of 3-element vectors
@@ -1491,7 +1492,7 @@ def test_align_vectors_antiparallel():
14911492
assert_allclose(R.apply(b[0]), a[0], atol=atol)
14921493

14931494
# Test exact rotations near 180 deg
1494-
Rs = Rotation.random(100, random_state=0)
1495+
Rs = Rotation.random(100, rng=0)
14951496
dRs = Rotation.from_rotvec(Rs.as_rotvec()*1e-4) # scale down to small angle
14961497
a = [[ 1, 0, 0], [0, 1, 0]]
14971498
b = [[-1, 0, 0], [0, 1, 0]]
@@ -1506,8 +1507,8 @@ def test_align_vectors_antiparallel():
15061507

15071508
def test_align_vectors_primary_only():
15081509
atol = 1e-12
1509-
mats_a = Rotation.random(100, random_state=0).as_matrix()
1510-
mats_b = Rotation.random(100, random_state=1).as_matrix()
1510+
mats_a = Rotation.random(100, rng=0).as_matrix()
1511+
mats_b = Rotation.random(100, rng=1).as_matrix()
15111512
for mat_a, mat_b in zip(mats_a, mats_b):
15121513
# Get random 3-element unit vectors
15131514
a = mat_a[0]
@@ -1658,16 +1659,16 @@ def test_slerp_call_scalar_time():
16581659

16591660

16601661
def test_multiplication_stability():
1661-
qs = Rotation.random(50, random_state=0)
1662-
rs = Rotation.random(1000, random_state=1)
1662+
qs = Rotation.random(50, rng=0)
1663+
rs = Rotation.random(1000, rng=1)
16631664
for q in qs:
16641665
rs *= q * rs
16651666
assert_allclose(np.linalg.norm(rs.as_quat(), axis=1), 1)
16661667

16671668

16681669
def test_pow():
16691670
atol = 1e-14
1670-
p = Rotation.random(10, random_state=0)
1671+
p = Rotation.random(10, rng=0)
16711672
p_inv = p.inv()
16721673
# Test the short-cuts and other integers
16731674
for n in [-5, -2, -1, 0, 1, 2, 5]:
@@ -1703,14 +1704,14 @@ def test_pow():
17031704

17041705

17051706
def test_pow_errors():
1706-
p = Rotation.random(random_state=0)
1707+
p = Rotation.random(rng=0)
17071708
with pytest.raises(NotImplementedError, match='modulus not supported'):
17081709
pow(p, 1, 1)
17091710

17101711

17111712
def test_rotation_within_numpy_array():
1712-
single = Rotation.random(random_state=0)
1713-
multiple = Rotation.random(2, random_state=1)
1713+
single = Rotation.random(rng=0)
1714+
multiple = Rotation.random(2, rng=1)
17141715

17151716
array = np.array(single)
17161717
assert_equal(array.shape, ())
@@ -1762,7 +1763,7 @@ def test_as_euler_contiguous():
17621763

17631764

17641765
def test_concatenate():
1765-
rotation = Rotation.random(10, random_state=0)
1766+
rotation = Rotation.random(10, rng=0)
17661767
sizes = [1, 2, 3, 1, 3]
17671768
starts = [0] + list(np.cumsum(sizes))
17681769
split = [rotation[i:i + n] for i, n in zip(starts, sizes)]

0 commit comments

Comments
 (0)