File tree Expand file tree Collapse file tree 2 files changed +7
-7
lines changed
Expand file tree Collapse file tree 2 files changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -1770,14 +1770,9 @@ def verify_grad(
17701770 if rel_tol is None :
17711771 rel_tol = max (_type_tol [str (p .dtype )] for p in pt )
17721772
1773+ # Initialize RNG if not provided
17731774 if rng is None :
1774- raise TypeError (
1775- "rng should be a valid instance of "
1776- "numpy.random.RandomState. You may "
1777- "want to use tests.unittest"
1778- "_tools.verify_grad instead of "
1779- "pytensor.gradient.verify_grad."
1780- )
1775+ rng = np .random .default_rng ()
17811776
17821777 # We allow input downcast in `function`, because `numeric_grad` works in
17831778 # the most precise dtype used among the inputs, so we may need to cast
Original file line number Diff line number Diff line change 33from scipy .optimize import rosen_hess_prod
44
55import pytensor
6+ import pytensor .tensor as pt
67import pytensor .tensor .basic as ptb
78from pytensor .configdefaults import config
89from pytensor .gradient import (
@@ -602,6 +603,10 @@ def test_grad_constant(self):
602603 + str (g_one )
603604 )
604605
606+ def test_verify_grad_no_rng (self ):
607+ """Test `verify_grad` works without requiring an explicit RNG."""
608+ utt .verify_grad (pt .log , [2.0 ])
609+
605610
606611def test_known_grads ():
607612 # Tests that the grad method with no known_grads
You can’t perform that action at this time.
0 commit comments