39
39
from ._numdiff import approx_derivative
40
40
from scipy ._lib ._util import getfullargspec_no_self as _getfullargspec
41
41
from scipy ._lib ._util import (MapWrapper , check_random_state , _RichResult ,
42
- _call_callback_maybe_halt )
42
+ _call_callback_maybe_halt , _transition_to_rng )
43
43
from scipy .optimize ._differentiable_functions import ScalarFunction , FD_METHODS
44
44
from scipy ._lib ._array_api import (array_namespace , xp_atleast_nd ,
45
45
xp_create_diagonal )
@@ -1031,9 +1031,10 @@ def approx_fprime(xk, f, epsilon=_epsilon, *args):
1031
1031
args = args , f0 = f0 )
1032
1032
1033
1033
1034
+ @_transition_to_rng ("seed" , position_num = 6 )
1034
1035
def check_grad (func , grad , x0 , * args , epsilon = _epsilon ,
1035
- direction = 'all' , seed = None ):
1036
- """Check the correctness of a gradient function by comparing it against a
1036
+ direction = 'all' , rng = None ):
1037
+ r """Check the correctness of a gradient function by comparing it against a
1037
1038
(forward) finite-difference approximation of the gradient.
1038
1039
1039
1040
Parameters
@@ -1056,17 +1057,15 @@ def check_grad(func, grad, x0, *args, epsilon=_epsilon,
1056
1057
using `func`. By default it is ``'all'``, in which case, all
1057
1058
the one hot direction vectors are considered to check `grad`.
1058
1059
If `func` is a vector valued function then only ``'all'`` can be used.
1059
- seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
1060
- If `seed` is None (or `np.random`), the `numpy.random.RandomState`
1061
- singleton is used.
1062
- If `seed` is an int, a new ``RandomState`` instance is used,
1063
- seeded with `seed`.
1064
- If `seed` is already a ``Generator`` or ``RandomState`` instance then
1065
- that instance is used.
1066
- Specify `seed` for reproducing the return value from this function.
1067
- The random numbers generated with this seed affect the random vector
1068
- along which gradients are computed to check ``grad``. Note that `seed`
1069
- is only used when `direction` argument is set to `'random'`.
1060
+ rng : `numpy.random.Generator`, optional
1061
+ Pseudorandom number generator state. When `rng` is None, a new
1062
+ `numpy.random.Generator` is created using entropy from the
1063
+ operating system. Types other than `numpy.random.Generator` are
1064
+ passed to `numpy.random.default_rng` to instantiate a ``Generator``.
1065
+
1066
+ The random numbers generated affect the random vector along which gradients
1067
+ are computed to check ``grad``. Note that `rng` is only used when `direction`
1068
+ argument is set to `'random'`.
1070
1069
1071
1070
Returns
1072
1071
-------
@@ -1106,8 +1105,8 @@ def g(w, func, x0, v, *args):
1106
1105
if _grad .ndim > 1 :
1107
1106
raise ValueError ("'random' can only be used with scalar valued"
1108
1107
" func" )
1109
- random_state = check_random_state (seed )
1110
- v = random_state . normal ( 0 , 1 , size = (x0 .shape ))
1108
+ rng_gen = check_random_state (rng )
1109
+ v = rng_gen . standard_normal ( size = (x0 .shape ))
1111
1110
_args = (func , x0 , v ) + args
1112
1111
_func = g
1113
1112
vars = np .zeros ((1 ,))
0 commit comments