Skip to content

Commit 6c7f1ba

Browse files
Fix float32 tests
1 parent f9a7865 commit 6c7f1ba

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

pytensor/tensor/slinalg.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,8 @@ def make_node(self, A, B):
     def perform(self, node, inputs, output_storage):
         (A, B) = inputs
        X = output_storage[0]
-
-        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B)
+        out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        X[0] = scipy.linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
@@ -830,7 +830,10 @@ def perform(self, node, inputs, output_storage):
        (A, B) = inputs
        X = output_storage[0]

-        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear")
+        dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
+        X[0] = scipy.linalg.solve_discrete_lyapunov(A, B, method="bilinear").astype(
+            dtype
+        )

     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]

tests/tensor/test_slinalg.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,33 @@ def test_solve_discrete_lyapunov(
     x = solve_discrete_lyapunov(a, q, method=method)
     f = function([a, q], x)

-    A = rng.normal(size=shape)
-    Q = rng.normal(size=shape)
+    A = rng.normal(size=shape).astype(dtype)
+    Q = rng.normal(size=shape).astype(dtype)

     X = f(A, Q)
     Q_recovered = vec_recover_Q(A, X, continuous=False)
-    np.testing.assert_allclose(Q_recovered, Q)
+
+    atol = rtol = 1e-4 if config.floatX == "float32" else 1e-8
+    np.testing.assert_allclose(Q_recovered, Q, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("use_complex", [False, True], ids=["float", "complex"])
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
+@pytest.mark.parametrize("method", ["direct", "bilinear"])
+def test_solve_discrete_lyapunov_gradient(
+    use_complex, shape: tuple[int], method: Literal["direct", "bilinear"]
+):
+    if config.floatX == "float32":
+        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    dtype = config.floatX
+    if use_complex:
+        precision = int(dtype[-2:])  # 64 or 32
+        dtype = f"complex{int(2 * precision)}"
+
+    A = rng.normal(size=shape).astype(dtype)
+    Q = rng.normal(size=shape).astype(dtype)

     utt.verify_grad(
         functools.partial(solve_discrete_lyapunov, method=method),
@@ -564,13 +585,25 @@ def test_solve_continuous_lyapunov(shape: tuple[int]):
     q = pt.tensor(name="q", shape=shape)
     f = function([a, q], [solve_continuous_lyapunov(a, q)])

-    A = rng.normal(size=shape)
-    Q = rng.normal(size=shape)
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)
     X = f(A, Q)

     Q_recovered = vec_recover_Q(A, X, continuous=True)

-    np.testing.assert_allclose(Q_recovered.squeeze(), Q)
+    atol = rtol = 1e-2 if config.floatX == "float32" else 1e-8
+    np.testing.assert_allclose(Q_recovered.squeeze(), Q, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batched"])
+def test_solve_continuous_lyapunov_grad(shape: tuple[int]):
+    if config.floatX == "float32":
+        pytest.skip(reason="Not enough precision in float32 to get a good gradient")
+
+    rng = np.random.default_rng(utt.fetch_seed())
+    A = rng.normal(size=shape).astype(config.floatX)
+    Q = rng.normal(size=shape).astype(config.floatX)
+
     utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)

0 commit comments

Comments (0)