Skip to content

Commit 08b8c3c

Browse files
Add more tests
1 parent 21c10a1 commit 08b8c3c

File tree

1 file changed

+41
-8
lines changed

1 file changed

+41
-8
lines changed

tests/tensor/test_optimize.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import numpy as np
2+
13
import pytensor.tensor as pt
4+
from pytensor import config
25
from pytensor.tensor.optimize import minimize
6+
from tests import unittest_tools as utt
7+
38

9+
floatX = config.floatX
410

5-
def test_minimize():
11+
12+
def test_simple_minimize():
613
x = pt.scalar("x")
714
a = pt.scalar("a")
815
c = pt.scalar("c")
@@ -11,16 +18,42 @@ def test_minimize():
1118
b.name = "b"
1219
out = (x - b * c) ** 2
1320

14-
minimized_x, success = minimize(out, x, debug=False)
21+
minimized_x, success = minimize(out, x)
1522

16-
a_val = 2
17-
c_val = 3
23+
a_val = 2.0
24+
c_val = 3.0
1825

1926
assert success
2027
assert minimized_x.eval({a: a_val, c: c_val, x: 0.0}) == (2 * a_val * c_val)
2128

22-
x_grad, a_grad, c_grad = pt.grad(minimized_x, [x, a, c])
29+
def f(x, a, b):
30+
objective = (x - a * b) ** 2
31+
out = minimize(objective, x)[0]
32+
return out
33+
34+
utt.verify_grad(f, [0.0, a_val, c_val], eps=1e-6)
35+
36+
37+
def test_minimize_vector_x():
38+
def rosenbrock_shifted_scaled(x, a, b):
39+
return (a * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum() + b
40+
41+
x = pt.dvector("x")
42+
a = pt.scalar("a")
43+
b = pt.scalar("b")
44+
45+
objective = rosenbrock_shifted_scaled(x, a, b)
46+
47+
minimized_x, success = minimize(objective, x, method="BFGS")
48+
49+
a_val = 0.5
50+
b_val = 1.0
51+
x0 = np.zeros(5).astype(floatX)
52+
x_star_val = minimized_x.eval({a: a_val, b: b_val, x: x0})
53+
54+
assert success
55+
np.testing.assert_allclose(
56+
x_star_val, np.ones_like(x_star_val), atol=1e-6, rtol=1e-6
57+
)
2358

24-
assert x_grad.eval({x: 0.0}) == 0.0
25-
assert a_grad.eval({a: a_val, c: c_val, x: 0.0}) == 2 * c_val
26-
assert c_grad.eval({a: a_val, c: c_val, x: 0.0}) == 2 * a_val
59+
utt.verify_grad(rosenbrock_shifted_scaled, [x0, a_val, b_val], eps=1e-6)

0 commit comments

Comments
 (0)