
Commit cb3d7e7: Changes to support float32 inputs

Parent: 33889f0

2 files changed: +67 −34 lines

pytensor/tensor/optimize.py (19 additions, 13 deletions)
```diff
@@ -52,6 +52,9 @@ def __init__(self, fn, copy_x: bool = False):
         self.last_result = None
         self.copy_x = copy_x

+        # Scipy does not respect dtypes *at all*, so we have to force it ourselves.
+        self.dtype = fn.maker.fgraph.inputs[0].type.dtype
+
         self.cache_hits = 0
         self.cache_misses = 0
@@ -67,9 +70,7 @@ def __call__(self, x, *args):
         If the input `x` is the same as the last input, return the cached result. Otherwise update the cache with the
         new input and result.
         """
-        # scipy.optimize.scalar_minimize and scalar_root don't take initial values as an argument, so we can't control
-        # the first input to the inner function. Of course, they use a scalar, but we need a 0d numpy array.
-        x = np.asarray(x)
+        x = x.astype(self.dtype)

         if self.last_result is None or not (x == self.last_x).all():
            self.cache_misses += 1
```
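For context on the cast above: scipy's optimizers may call the wrapped objective with a python float or a float64 array regardless of the initial value's dtype, so the cache now normalizes every input back to the dtype of the compiled graph's first input. A minimal standalone sketch of the idea (a toy class, not the actual `LRUCache1`):

```python
import numpy as np


class DtypeForcingWrapper:
    """Toy version of the change: force whatever scipy passes in back to
    the dtype the compiled inner function was built with."""

    def __init__(self, fn, dtype):
        self.fn = fn
        self.dtype = dtype  # in the real code: fn.maker.fgraph.inputs[0].type.dtype

    def __call__(self, x, *args):
        # np.asarray handles python floats and 0d arrays; astype pins the dtype
        x = np.asarray(x).astype(self.dtype)
        return self.fn(x, *args)


# Usage: wrapped = DtypeForcingWrapper(compiled_fn, "float32")
```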
```diff
@@ -160,6 +161,7 @@ def _get_parameter_grads_from_vector(
         )

         grad_wrt_args.append(dot(output_grad, arg_grad))
+
         cursor += arg_size

     return grad_wrt_args
@@ -175,17 +177,11 @@ def build_fn(self):
         """
         outputs = self.inner_outputs
         self._fn = fn = function(self.inner_inputs, outputs, trust_input=True)
+
         # Do this reassignment to see the compiled graph in the dprint
         # self.fgraph = fn.maker.fgraph

-        if self.inner_inputs[0].type.shape == ():
-
-            def fn_wrapper(x, *args):
-                return fn(x.squeeze(), *args)
-
-            self._fn_wrapped = LRUCache1(fn_wrapper)
-        else:
-            self._fn_wrapped = LRUCache1(fn)
+        self._fn_wrapped = LRUCache1(fn)

     @property
     def fn(self):
```
```diff
@@ -771,7 +767,9 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )

-        outputs[0][0] = res.x.reshape(variables.shape)
+        # There's a reshape here to cover the case where variables is a scalar. Scipy will still return a
+        # (1, 1) matrix in this case, which causes errors downstream (since pytensor expects a scalar).
+        outputs[0][0] = res.x.reshape(variables.shape).astype(variables.dtype)
         outputs[1][0] = np.bool_(res.success)

     def L_op(
```
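To make the new comment concrete: for a 0d (scalar) variable, scipy can hand back an array with extra dimensions and in float64, so the result is reshaped and cast before being written to the output storage. A small illustration with made-up values:

```python
import numpy as np

variables = np.array(0.5, dtype="float32")  # 0d initial value
res_x = np.array([[0.739]])                 # what scipy may return: shape (1, 1), float64

out = res_x.reshape(variables.shape).astype(variables.dtype)
print(out.shape, out.dtype)                 # -> () float32, as the graph expects
```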
```diff
@@ -807,12 +805,20 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    optimizer_kwargs: dict | None = None,
 ):
     """Find roots of a system of equations using scipy.optimize.root."""

     args = _find_optimization_parameters(equations, variables)

-    root_op = RootOp(variables, *args, equations=equations, method=method, jac=jac)
+    root_op = RootOp(
+        variables,
+        *args,
+        equations=equations,
+        method=method,
+        jac=jac,
+        optimizer_kwargs=optimizer_kwargs,
+    )

     return root_op(variables, *args)
```
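With this signature change, `root` now forwards `optimizer_kwargs` to `scipy.optimize.root`, mirroring what `minimize` already supports (the tests pass it `optimizer_kwargs={"tol": 1e-16}`). A usage sketch along the lines of the updated tests:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

x = pt.tensor("x", shape=(None,))
a = pt.tensor("a", shape=(None,))
b = pt.tensor("b", shape=(None,))

f = pt.stack([a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]])

# optimizer_kwargs is passed straight through to scipy.optimize.root
root_f, success = root(f, x, method="lm", optimizer_kwargs={"tol": 1e-8})
fn = pytensor.function([x, a, b], [root_f, success])
```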

tests/tensor/test_optimize.py (48 additions, 21 deletions)
```diff
@@ -59,7 +59,12 @@ def test_simple_minimize():
     minimized_x_val, success_val = f(a_val, c_val, 0.0)

     assert success_val
-    assert minimized_x_val == (2 * a_val * c_val)
+    np.testing.assert_allclose(
+        minimized_x_val,
+        2 * a_val * c_val,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )

     def f(x, a, b):
         objective = (x - a * b) ** 2
```
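The tolerance idiom used throughout the updated tests picks tight bounds under float64 and looser ones under float32, since single precision converges less exactly. The tests reference both `config.floatX` and a bare `floatX`, which is presumably aliased at module import time; a condensed version of the pattern (the alias is an assumption here):

```python
import numpy as np
from pytensor import config

floatX = config.floatX  # assumed module-level alias, as the tests use both names

atol = rtol = 1e-8 if floatX == "float64" else 1e-6
np.testing.assert_allclose(1.0 + 1e-9, 1.0, atol=atol, rtol=rtol)

# The same idea drives the finite-difference step in the verify_grad calls below:
# a larger eps (1e-3) under float32 keeps round-off from swamping the gradient check.
```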
```diff
@@ -82,7 +87,7 @@ def test_minimize_vector_x(method, jac, hess):
     def rosenbrock_shifted_scaled(x, a, b):
         return (a * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum() + b

-    x = pt.dvector("x")
+    x = pt.tensor("x", shape=(None,))
     a = pt.scalar("a")
     b = pt.scalar("b")
```
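The switch away from `pt.dvector` matters for this commit's goal: `dvector` hard-codes float64, while `pt.tensor(..., shape=(None,))` leaves the dtype to `config.floatX`, so the test exercises float32 when pytensor is configured for it. For instance:

```python
import pytensor.tensor as pt
from pytensor import config

x64 = pt.dvector("x")              # always float64
x = pt.tensor("x", shape=(None,))  # dtype defaults to config.floatX

assert x64.type.dtype == "float64"
assert x.type.dtype == config.floatX
```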

```diff
@@ -91,23 +96,30 @@ def rosenbrock_shifted_scaled(x, a, b):
         objective, x, method=method, jac=jac, hess=hess, optimizer_kwargs={"tol": 1e-16}
     )

-    a_val = 0.5
-    b_val = 1.0
-    x0 = np.zeros(5).astype(floatX)
-    x_star_val = minimized_x.eval({a: a_val, b: b_val, x: x0})
+    fn = pytensor.function([x, a, b], [minimized_x, success])

-    assert success.eval({a: a_val, b: b_val, x: x0})
+    a_val = np.array(0.5, dtype=floatX)
+    b_val = np.array(1.0, dtype=floatX)
+    x0 = np.zeros((5,)).astype(floatX)
+    x_star_val, success = fn(x0, a_val, b_val)
+
+    assert success

     np.testing.assert_allclose(
-        x_star_val, np.ones_like(x_star_val), atol=1e-6, rtol=1e-6
+        x_star_val,
+        np.ones_like(x_star_val),
+        atol=1e-8 if config.floatX == "float64" else 1e-3,
+        rtol=1e-8 if config.floatX == "float64" else 1e-3,
     )

+    assert x_star_val.dtype == floatX
+
     def f(x, a, b):
         objective = rosenbrock_shifted_scaled(x, a, b)
         out = minimize(objective, x)[0]
         return out

-    utt.verify_grad(f, [x0, a_val, b_val], eps=1e-6)
+    utt.verify_grad(f, [x0, a_val, b_val], eps=1e-3 if floatX == "float32" else 1e-6)


 @pytest.mark.parametrize(
```
```diff
@@ -130,7 +142,12 @@ def fn(x, a):
     solution, success = func(x0, a_val)

     assert success
-    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+    np.testing.assert_allclose(
+        solution,
+        -1.02986653,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )

     def root_fn(x, a):
         f = fn(x, a)
@@ -147,15 +164,20 @@ def fn(x, a):
         return x + 2 * a * pt.cos(x)

     f = fn(x, a)
-    root_f, success = root(f, x)
+    root_f, success = root(f, x, method="lm", optimizer_kwargs={"tol": 1e-8})
     func = pytensor.function([x, a], [root_f, success])

     x0 = 0.0
     a_val = 1.0
     solution, success = func(x0, a_val)

     assert success
-    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+    np.testing.assert_allclose(
+        solution,
+        -1.02986653,
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
+    )

     def root_fn(x, a):
         f = fn(x, a)
```
```diff
@@ -165,24 +187,27 @@ def root_fn(x, a):


 def test_root_system_of_equations():
-    x = pt.dvector("x")
-    a = pt.dvector("a")
-    b = pt.dvector("b")
+    x = pt.tensor("x", shape=(None,))
+    a = pt.tensor("a", shape=(None,))
+    b = pt.tensor("b", shape=(None,))

     f = pt.stack([a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]])

-    root_f, success = root(f, x)
+    root_f, success = root(f, x, method="lm", optimizer_kwargs={"tol": 1e-8})
     func = pytensor.function([x, a, b], [root_f, success])

-    x0 = np.array([1.0, 1.0])
-    a_val = np.array([1.0, 1.0])
-    b_val = np.array([4.0, 5.0])
+    x0 = np.array([1.0, 1.0], dtype=floatX)
+    a_val = np.array([1.0, 1.0], dtype=floatX)
+    b_val = np.array([4.0, 5.0], dtype=floatX)
     solution, success = func(x0, a_val, b_val)

     assert success

     np.testing.assert_allclose(
-        solution, np.array([6.50409711, 0.90841421]), atol=1e-6, rtol=1e-6
+        solution,
+        np.array([6.50409711, 0.90841421]),
+        atol=1e-8 if config.floatX == "float64" else 1e-6,
+        rtol=1e-8 if config.floatX == "float64" else 1e-6,
     )

     def root_fn(x, a, b):
@@ -191,4 +216,6 @@ def root_fn(x, a, b):
         )
         return root(f, x)[0]

-    utt.verify_grad(root_fn, [x0, a_val, b_val], eps=1e-6)
+    utt.verify_grad(
+        root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3
+    )
```
