
Commit 3bfe589

MAINT: SPEC-007 optimize.check_grad
1 parent fb88f55 commit 3bfe589

File tree

2 files changed (+37, -18 lines)

scipy/optimize/_optimize.py

Lines changed: 31 additions & 13 deletions
@@ -39,7 +39,7 @@
 from ._numdiff import approx_derivative
 from scipy._lib._util import getfullargspec_no_self as _getfullargspec
 from scipy._lib._util import (MapWrapper, check_random_state, _RichResult,
-                              _call_callback_maybe_halt)
+                              _call_callback_maybe_halt, _transition_to_rng)
 from scipy.optimize._differentiable_functions import ScalarFunction, FD_METHODS
 from scipy._lib._array_api import (array_namespace, xp_atleast_nd,
                                    xp_create_diagonal)
@@ -1031,9 +1031,10 @@ def approx_fprime(xk, f, epsilon=_epsilon, *args):
                              args=args, f0=f0)
 
 
+@_transition_to_rng("seed", position_num=6, replace_doc=False)
 def check_grad(func, grad, x0, *args, epsilon=_epsilon,
-               direction='all', seed=None):
-    """Check the correctness of a gradient function by comparing it against a
+               direction='all', rng=None):
+    r"""Check the correctness of a gradient function by comparing it against a
     (forward) finite-difference approximation of the gradient.
 
     Parameters
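
The decorator added in this hunk drives the SPEC-007 keyword hand-off. A rough, hypothetical sketch of the pattern is below; the names and details are illustrative only and are not SciPy's private `_transition_to_rng` (whose `position_num` and `replace_doc` arguments additionally handle positional calls and docstring rewriting). The idea: accept either the legacy `seed` keyword or the new `rng` keyword, reject calls that pass both, and hand the wrapped function a single `rng` value.

    import functools
    import numpy as np

    def transition_to_rng_sketch(old_name="seed"):
        # Hypothetical decorator factory; illustrates only the keyword hand-off.
        def decorator(fun):
            @functools.wraps(fun)
            def wrapper(*args, **kwargs):
                if old_name in kwargs and "rng" in kwargs:
                    raise TypeError(f"pass either '{old_name}' or 'rng', not both")
                if old_name in kwargs:
                    # Legacy path: forward the value unchanged so the existing
                    # RandomState-based behavior is preserved for now.
                    kwargs["rng"] = kwargs.pop(old_name)
                elif kwargs.get("rng") is not None:
                    rng = kwargs["rng"]
                    if not isinstance(rng, np.random.Generator):
                        # New path: promote ints, SeedSequences, etc. to a Generator.
                        kwargs["rng"] = np.random.default_rng(rng)
                return fun(*args, **kwargs)
            return wrapper
        return decorator
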
@@ -1056,14 +1057,31 @@ def check_grad(func, grad, x0, *args, epsilon=_epsilon,
         using `func`. By default it is ``'all'``, in which case, all
         the one hot direction vectors are considered to check `grad`.
         If `func` is a vector valued function then only ``'all'`` can be used.
-    seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
-        If `seed` is None (or `np.random`), the `numpy.random.RandomState`
-        singleton is used.
-        If `seed` is an int, a new ``RandomState`` instance is used,
-        seeded with `seed`.
-        If `seed` is already a ``Generator`` or ``RandomState`` instance then
-        that instance is used.
-        Specify `seed` for reproducing the return value from this function.
+    rng : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
+        If `rng` is passed by keyword, types other than `numpy.random.Generator` are
+        passed to `numpy.random.default_rng` to instantiate a ``Generator``.
+        If `rng` is already a ``Generator`` instance, then the provided instance is
+        used. Specify `rng` for repeatable function behavior.
+
+        If this argument is passed by position or `seed` is passed by keyword,
+        legacy behavior for the argument `seed` applies:
+
+        - If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
+          singleton is used.
+        - If `seed` is an int, a new ``RandomState`` instance is used,
+          seeded with `seed`.
+        - If `seed` is already a ``Generator`` or ``RandomState`` instance then
+          that instance is used.
+
+        .. versionchanged:: 1.15.0
+            As part of the `SPEC-007 <https://scientific-python.org/specs/spec-0007/>`_
+            transition from use of `numpy.random.RandomState` to
+            `numpy.random.Generator`, this keyword was changed from `seed` to `rng`.
+            For an interim period, both keywords will continue to work, although only one
+            may be specified at a time. After the interim period, function calls using the
+            `seed` keyword will emit warnings. The behavior of both `seed` and
+            `rng` are outlined above, but only the `rng` keyword should be used in new code.
+
         The random numbers generated with this seed affect the random vector
         along which gradients are computed to check ``grad``. Note that `seed`
         is only used when `direction` argument is set to `'random'`.
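
A short usage sketch of the renamed keyword follows. The toy quadratic is illustrative (it is not part of the commit), and it assumes a SciPy build that includes this change (1.15.0 or later):

    import numpy as np
    from scipy import optimize

    def f(x):
        return np.sum(x**2)

    def df(x):
        return 2 * x

    x0 = np.array([1.5, -0.5, 2.0])

    # New SPEC-007 style: pass `rng` by keyword; a Generator is used as-is,
    # and other values are promoted via numpy.random.default_rng.
    err = optimize.check_grad(f, df, x0, direction='random',
                              rng=np.random.default_rng(1234))

    # Legacy style: `seed` continues to work during the interim period
    # described in the versionchanged note above.
    err_legacy = optimize.check_grad(f, df, x0, direction='random', seed=1234)

    # Both values should be close to 0 for a correct gradient.
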
@@ -1106,8 +1124,8 @@ def g(w, func, x0, v, *args):
         if _grad.ndim > 1:
             raise ValueError("'random' can only be used with scalar valued"
                              " func")
-        random_state = check_random_state(seed)
-        v = random_state.normal(0, 1, size=(x0.shape))
+        rng_gen = check_random_state(rng)
+        v = rng_gen.standard_normal(size=(x0.shape))
         _args = (func, x0, v) + args
         _func = g
         vars = np.zeros((1,))
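
For intuition, the 'random' direction strategy these lines feed can be sketched as a stand-alone check: project the analytical gradient onto a random vector v and compare it with a forward difference of g(w) = func(x0 + w*v) at w = 0. The sketch below is a simplified illustration of that idea (names and epsilon are assumptions), not the code path SciPy actually runs:

    import numpy as np

    def random_direction_check(func, grad, x0, rng,
                               epsilon=np.sqrt(np.finfo(float).eps)):
        # Draw a random direction, as the patched code does with
        # rng_gen.standard_normal(size=x0.shape).
        v = rng.standard_normal(size=x0.shape)
        # Directional derivative from the user-supplied gradient ...
        analytical = np.dot(grad(x0), v)
        # ... compared against a forward difference of g(w) = func(x0 + w*v).
        numerical = (func(x0 + epsilon * v) - func(x0)) / epsilon
        return abs(analytical - numerical)

    rng = np.random.default_rng(1234)
    err = random_direction_check(lambda x: np.sum(x**2), lambda x: 2 * x,
                                 np.array([1.0, -2.0, 0.5]), rng)
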

scipy/optimize/tests/test_optimize.py

Lines changed: 6 additions & 5 deletions
@@ -52,21 +52,22 @@ def der_expit(x):
 
     r = optimize.check_grad(expit, der_expit, x0)
     assert_almost_equal(r, 0)
+    # SPEC-007 leave one call with seed to check it still works
     r = optimize.check_grad(expit, der_expit, x0,
                             direction='random', seed=1234)
     assert_almost_equal(r, 0)
 
     r = optimize.check_grad(expit, der_expit, x0, epsilon=1e-6)
     assert_almost_equal(r, 0)
     r = optimize.check_grad(expit, der_expit, x0, epsilon=1e-6,
-                            direction='random', seed=1234)
+                            direction='random', rng=1234)
     assert_almost_equal(r, 0)
 
     # Check if the epsilon parameter is being considered.
     r = abs(optimize.check_grad(expit, der_expit, x0, epsilon=1e-1) - 0)
     assert r > 1e-7
     r = abs(optimize.check_grad(expit, der_expit, x0, epsilon=1e-1,
-                                direction='random', seed=1234) - 0)
+                                direction='random', rng=1234) - 0)
     assert r > 1e-7
 
     def x_sinx(x):
@@ -78,16 +79,16 @@ def der_x_sinx(x):
     x0 = np.arange(0, 2, 0.2)
 
     r = optimize.check_grad(x_sinx, der_x_sinx, x0,
-                            direction='random', seed=1234)
+                            direction='random', rng=1234)
     assert_almost_equal(r, 0)
 
     assert_raises(ValueError, optimize.check_grad,
                   x_sinx, der_x_sinx, x0,
-                  direction='random_projection', seed=1234)
+                  direction='random_projection', rng=1234)
 
     # checking can be done for derivatives of vector valued functions
     r = optimize.check_grad(himmelblau_grad, himmelblau_hess, himmelblau_x0,
-                            direction='all', seed=1234)
+                            direction='all', rng=1234)
     assert r < 5e-7