@@ -59,7 +59,12 @@ def test_simple_minimize():
     minimized_x_val, success_val = f(a_val, c_val, 0.0)
 
     assert success_val
-    assert minimized_x_val == (2 * a_val * c_val)
+    np.testing.assert_allclose(
+        minimized_x_val,
+        2 * a_val * c_val,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )
 
     def f(x, a, b):
         objective = (x - a * b) ** 2
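Note on this hunk: exact `==` against the float returned by an iterative optimizer is brittle, so the assertion becomes `np.testing.assert_allclose` with tolerances keyed to `config.floatX`. A minimal sketch of the pattern used throughout this diff; the `allclose_floatX` helper is hypothetical, not part of the PR:

```python
import numpy as np
from pytensor import config


def allclose_floatX(actual, desired, tol64=1e-8, tol32=1e-6):
    # Hypothetical helper: pick tolerances from the configured float width,
    # mirroring the inline conditionals added in this diff.
    tol = tol64 if config.floatX == "float64" else tol32
    np.testing.assert_allclose(actual, desired, atol=tol, rtol=tol)
```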
@@ -82,7 +87,7 @@ def test_minimize_vector_x(method, jac, hess):
     def rosenbrock_shifted_scaled(x, a, b):
         return (a * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum() + b
 
-    x = pt.dvector("x")
+    x = pt.tensor("x", shape=(None,))
     a = pt.scalar("a")
     b = pt.scalar("b")
 
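Why this one-liner matters: `pt.dvector` pins the dtype to float64 no matter how PyTensor is configured, while `pt.tensor` with no explicit dtype follows `config.floatX`, so the test also exercises float32 mode. A quick sketch, assuming only pytensor:

```python
import pytensor.tensor as pt
from pytensor import config

x64 = pt.dvector("x")  # always float64, regardless of config.floatX
x = pt.tensor("x", shape=(None,))  # dtype defaults to config.floatX
assert x64.type.dtype == "float64"
assert x.type.dtype == config.floatX
```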
@@ -91,23 +96,30 @@ def rosenbrock_shifted_scaled(x, a, b):
         objective, x, method=method, jac=jac, hess=hess, optimizer_kwargs={"tol": 1e-16}
     )
 
-    a_val = 0.5
-    b_val = 1.0
-    x0 = np.zeros(5).astype(floatX)
-    x_star_val = minimized_x.eval({a: a_val, b: b_val, x: x0})
+    fn = pytensor.function([x, a, b], [minimized_x, success])
 
-    assert success.eval({a: a_val, b: b_val, x: x0})
+    a_val = np.array(0.5, dtype=floatX)
+    b_val = np.array(1.0, dtype=floatX)
+    x0 = np.zeros((5,)).astype(floatX)
+    x_star_val, success = fn(x0, a_val, b_val)
+
+    assert success
 
     np.testing.assert_allclose(
-        x_star_val, np.ones_like(x_star_val), atol=1e-6, rtol=1e-6
+        x_star_val,
+        np.ones_like(x_star_val),
+        atol=1e-8 if config.floatX == "float64" else 1e-3,
+        rtol=1e-8 if config.floatX == "float64" else 1e-3,
     )
 
+    assert x_star_val.dtype == floatX
+
     def f(x, a, b):
         objective = rosenbrock_shifted_scaled(x, a, b)
         out = minimize(objective, x)[0]
         return out
 
-    utt.verify_grad(f, [x0, a_val, b_val], eps=1e-6)
+    utt.verify_grad(f, [x0, a_val, b_val], eps=1e-3 if floatX == "float32" else 1e-6)
 
 
 @pytest.mark.parametrize(
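The `eps` bump in `verify_grad` (here and again in the last hunk) follows standard finite-difference analysis: the step must balance truncation error against round-off error, and the optimal step grows with machine epsilon (roughly √ε for a forward difference, ∛ε for a central difference), which is orders of magnitude larger in float32. A pure-NumPy sketch of that rule of thumb:

```python
import numpy as np

for dtype in (np.float32, np.float64):
    eps_mach = np.finfo(dtype).eps
    # forward diff ~ sqrt(eps_mach): ~3.5e-4 (float32) / ~1.5e-8 (float64)
    # central diff ~ cbrt(eps_mach): ~4.9e-3 (float32) / ~6.1e-6 (float64)
    print(dtype.__name__, np.sqrt(eps_mach), np.cbrt(eps_mach))
```

Either way, the chosen steps of 1e-3 under float32 and 1e-6 under float64 sit in the right range.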
@@ -130,7 +142,12 @@ def fn(x, a):
     solution, success = func(x0, a_val)
 
     assert success
-    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+    np.testing.assert_allclose(
+        solution,
+        -1.02986653,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )
 
     def root_fn(x, a):
         f = fn(x, a)
@@ -147,15 +164,20 @@ def fn(x, a):
         return x + 2 * a * pt.cos(x)
 
     f = fn(x, a)
-    root_f, success = root(f, x)
+    root_f, success = root(f, x, method="lm", optimizer_kwargs={"tol": 1e-8})
     func = pytensor.function([x, a], [root_f, success])
 
     x0 = 0.0
     a_val = 1.0
     solution, success = func(x0, a_val)
 
     assert success
-    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+    np.testing.assert_allclose(
+        solution,
+        -1.02986653,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )
 
     def root_fn(x, a):
         f = fn(x, a)
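Pinning `method="lm"` makes the root-finding deterministic across SciPy versions instead of relying on the default solver. Assuming the `root` wrapper forwards `method` and `optimizer_kwargs` to `scipy.optimize.root` (which is how it reads from this diff), the new call corresponds roughly to:

```python
import numpy as np
from scipy.optimize import root as scipy_root

# Levenberg-Marquardt (MINPACK) on the same scalar residual, with a = 1.0.
res = scipy_root(lambda x: x + 2.0 * np.cos(x), x0=0.0, method="lm", tol=1e-8)
assert res.success
np.testing.assert_allclose(res.x, -1.02986653, rtol=1e-6)
```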
@@ -165,24 +187,27 @@ def root_fn(x, a):
 
 
 def test_root_system_of_equations():
-    x = pt.dvector("x")
-    a = pt.dvector("a")
-    b = pt.dvector("b")
+    x = pt.tensor("x", shape=(None,))
+    a = pt.tensor("a", shape=(None,))
+    b = pt.tensor("b", shape=(None,))
 
     f = pt.stack([a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]])
 
-    root_f, success = root(f, x)
+    root_f, success = root(f, x, method="lm", optimizer_kwargs={"tol": 1e-8})
     func = pytensor.function([x, a, b], [root_f, success])
 
-    x0 = np.array([1.0, 1.0])
-    a_val = np.array([1.0, 1.0])
-    b_val = np.array([4.0, 5.0])
+    x0 = np.array([1.0, 1.0], dtype=floatX)
+    a_val = np.array([1.0, 1.0], dtype=floatX)
+    b_val = np.array([4.0, 5.0], dtype=floatX)
     solution, success = func(x0, a_val, b_val)
 
     assert success
 
     np.testing.assert_allclose(
-        solution, np.array([6.50409711, 0.90841421]), atol=1e-6, rtol=1e-6
+        solution,
+        np.array([6.50409711, 0.90841421]),
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )
 
     def root_fn(x, a, b):
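Casting the test inputs with `dtype=floatX` (here and in the vector-x test above) becomes necessary once the symbolic inputs follow `floatX`: a compiled PyTensor function rejects float64 arrays for a float32-typed input rather than silently downcasting. A small sketch of the failure mode this avoids, using an explicitly float32 input for illustration:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.fvector("v")  # explicitly float32, standing in for floatX="float32"
g = pytensor.function([v], 2 * v)

g(np.zeros(2, dtype="float32"))  # accepted
# g(np.zeros(2))                 # float64 input -> TypeError, no implicit downcast
```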
@@ -191,4 +216,6 @@ def root_fn(x, a, b):
         )
         return root(f, x)[0]
 
-    utt.verify_grad(root_fn, [x0, a_val, b_val], eps=1e-6)
+    utt.verify_grad(
+        root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3
+    )