Skip to content

Commit 581f65a

Browse files
committed
Adapt to Solve changes in Scipy 1.15
1. Use actual Solve Op to infer output dtype as CholSolve outputs a different dtype than basic Solve in Scipy==1.15 2. Tweaked test related to #1152 3. Tweak tolerage
1 parent cff058c commit 581f65a

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

pytensor/tensor/slinalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,10 @@ def make_node(self, A, b):
259259
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
260260

261261
# Infer dtype by solving the most simple case with 1x1 matrices
262-
o_dtype = scipy.linalg.solve(
263-
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
264-
).dtype
262+
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
263+
out_arr = [[None]]
264+
self.perform(None, inp_arr, out_arr)
265+
o_dtype = out_arr[0][0].dtype
265266
x = tensor(dtype=o_dtype, shape=b.type.shape)
266267
return Apply(self, [A, b], [x])
267268

tests/tensor/test_blockwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def core_scipy_fn(A, b):
590590
A_val_copy, b_val_copy
591591
)
592592
np.testing.assert_allclose(
593-
out, expected_out, atol=1e-5 if config.floatX == "float32" else 0
593+
out, expected_out, atol=1e-4 if config.floatX == "float32" else 0
594594
)
595595

596596
# Confirm input was destroyed

tests/tensor/test_slinalg.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,12 @@ def test_eigvalsh_grad():
169169
)
170170

171171

172-
class TestSolveBase(utt.InferShapeTester):
172+
class TestSolveBase:
173+
class SolveTest(SolveBase):
174+
def perform(self, node, inputs, outputs):
175+
A, b = inputs
176+
outputs[0][0] = scipy.linalg.solve(A, b)
177+
173178
@pytest.mark.parametrize(
174179
"A_func, b_func, error_message",
175180
[
@@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message):
191196
with pytest.raises(ValueError, match=error_message):
192197
A = A_func()
193198
b = b_func()
194-
SolveBase(b_ndim=2)(A, b)
199+
self.SolveTest(b_ndim=2)(A, b)
195200

196201
def test__repr__(self):
197202
np.random.default_rng(utt.fetch_seed())
198203
A = matrix()
199204
b = matrix()
200-
y = SolveBase(b_ndim=2)(A, b)
205+
y = self.SolveTest(b_ndim=2)(A, b)
201206
assert (
202207
y.__repr__()
203-
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
208+
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
204209
)
205210

206211

@@ -239,8 +244,9 @@ def test_correctness(self):
239244
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
240245
A_val = np.dot(A_val.transpose(), A_val)
241246

242-
assert np.allclose(
243-
scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val)
247+
np.testing.assert_allclose(
248+
scipy.linalg.solve(A_val, b_val, assume_a="gen"),
249+
gen_solve_func(A_val, b_val),
244250
)
245251

246252
A_undef = np.array(
@@ -253,7 +259,7 @@ def test_correctness(self):
253259
],
254260
dtype=config.floatX,
255261
)
256-
assert np.allclose(
262+
np.testing.assert_allclose(
257263
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
258264
)
259265

@@ -450,7 +456,7 @@ def test_solve_dtype(self):
450456
fn = function([A, b], x)
451457
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
452458

453-
assert x.dtype == x_result.dtype
459+
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)
454460

455461

456462
def test_cho_solve():

0 commit comments

Comments
 (0)