Skip to content

Commit 3367349

Browse files
authored
Merge pull request scipy#21845 from andyfaff/spec7
MAINT: SPEC-007 optimize.check_grad
2 parents 058dae1 + d7c1a19 commit 3367349

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

scipy/optimize/_optimize.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from ._numdiff import approx_derivative
4040
from scipy._lib._util import getfullargspec_no_self as _getfullargspec
4141
from scipy._lib._util import (MapWrapper, check_random_state, _RichResult,
42-
_call_callback_maybe_halt)
42+
_call_callback_maybe_halt, _transition_to_rng)
4343
from scipy.optimize._differentiable_functions import ScalarFunction, FD_METHODS
4444
from scipy._lib._array_api import (array_namespace, xp_atleast_nd,
4545
xp_create_diagonal)
@@ -1031,9 +1031,10 @@ def approx_fprime(xk, f, epsilon=_epsilon, *args):
10311031
args=args, f0=f0)
10321032

10331033

1034+
@_transition_to_rng("seed", position_num=6)
10341035
def check_grad(func, grad, x0, *args, epsilon=_epsilon,
1035-
direction='all', seed=None):
1036-
"""Check the correctness of a gradient function by comparing it against a
1036+
direction='all', rng=None):
1037+
r"""Check the correctness of a gradient function by comparing it against a
10371038
(forward) finite-difference approximation of the gradient.
10381039
10391040
Parameters
@@ -1056,17 +1057,15 @@ def check_grad(func, grad, x0, *args, epsilon=_epsilon,
10561057
using `func`. By default it is ``'all'``, in which case, all
10571058
the one hot direction vectors are considered to check `grad`.
10581059
If `func` is a vector valued function then only ``'all'`` can be used.
1059-
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
1060-
If `seed` is None (or `np.random`), the `numpy.random.RandomState`
1061-
singleton is used.
1062-
If `seed` is an int, a new ``RandomState`` instance is used,
1063-
seeded with `seed`.
1064-
If `seed` is already a ``Generator`` or ``RandomState`` instance then
1065-
that instance is used.
1066-
Specify `seed` for reproducing the return value from this function.
1067-
The random numbers generated with this seed affect the random vector
1068-
along which gradients are computed to check ``grad``. Note that `seed`
1069-
is only used when `direction` argument is set to `'random'`.
1060+
rng : `numpy.random.Generator`, optional
1061+
Pseudorandom number generator state. When `rng` is None, a new
1062+
`numpy.random.Generator` is created using entropy from the
1063+
operating system. Types other than `numpy.random.Generator` are
1064+
passed to `numpy.random.default_rng` to instantiate a ``Generator``.
1065+
1066+
The random numbers generated affect the random vector along which gradients
1067+
are computed to check ``grad``. Note that `rng` is only used when `direction`
1068+
argument is set to `'random'`.
10701069
10711070
Returns
10721071
-------
@@ -1106,8 +1105,8 @@ def g(w, func, x0, v, *args):
11061105
if _grad.ndim > 1:
11071106
raise ValueError("'random' can only be used with scalar valued"
11081107
" func")
1109-
random_state = check_random_state(seed)
1110-
v = random_state.normal(0, 1, size=(x0.shape))
1108+
rng_gen = check_random_state(rng)
1109+
v = rng_gen.standard_normal(size=(x0.shape))
11111110
_args = (func, x0, v) + args
11121111
_func = g
11131112
vars = np.zeros((1,))

scipy/optimize/tests/test_optimize.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,22 @@ def der_expit(x):
5252

5353
r = optimize.check_grad(expit, der_expit, x0)
5454
assert_almost_equal(r, 0)
55+
# SPEC-007 leave one call with seed to check it still works
5556
r = optimize.check_grad(expit, der_expit, x0,
5657
direction='random', seed=1234)
5758
assert_almost_equal(r, 0)
5859

5960
r = optimize.check_grad(expit, der_expit, x0, epsilon=1e-6)
6061
assert_almost_equal(r, 0)
6162
r = optimize.check_grad(expit, der_expit, x0, epsilon=1e-6,
62-
direction='random', seed=1234)
63+
direction='random', rng=1234)
6364
assert_almost_equal(r, 0)
6465

6566
# Check if the epsilon parameter is being considered.
6667
r = abs(optimize.check_grad(expit, der_expit, x0, epsilon=1e-1) - 0)
6768
assert r > 1e-7
6869
r = abs(optimize.check_grad(expit, der_expit, x0, epsilon=1e-1,
69-
direction='random', seed=1234) - 0)
70+
direction='random', rng=1234) - 0)
7071
assert r > 1e-7
7172

7273
def x_sinx(x):
@@ -78,16 +79,16 @@ def der_x_sinx(x):
7879
x0 = np.arange(0, 2, 0.2)
7980

8081
r = optimize.check_grad(x_sinx, der_x_sinx, x0,
81-
direction='random', seed=1234)
82+
direction='random', rng=1234)
8283
assert_almost_equal(r, 0)
8384

8485
assert_raises(ValueError, optimize.check_grad,
8586
x_sinx, der_x_sinx, x0,
86-
direction='random_projection', seed=1234)
87+
direction='random_projection', rng=1234)
8788

8889
# checking can be done for derivatives of vector valued functions
8990
r = optimize.check_grad(himmelblau_grad, himmelblau_hess, himmelblau_x0,
90-
direction='all', seed=1234)
91+
direction='all', rng=1234)
9192
assert r < 5e-7
9293

9394

0 commit comments

Comments (0)