Skip to content

Commit 19f2895

Browse files
Seed test_local_lift_through_linalg test
1 parent c038109 commit 19f2895

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,8 @@ def test_invalid_batched_a(self):
362362
ids=["block_diag", "kron"],
363363
)
364364
def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
365+
rng = np.random.default_rng(sum(map(ord, "lift_through_linalg")))
366+
365367
if pytensor.config.floatX.endswith("32"):
366368
pytest.skip("Test is flaky at half precision")
367369

@@ -387,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
387389
assert len(f_ops) == 2
388390
assert len(g_ops) == 1
389391

390-
test_vals = [
391-
np.random.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)
392-
]
392+
test_vals = [rng.normal(size=(3,) * A.ndim).astype(config.floatX) for _ in range(2)]
393393
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
394394

395395
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
@@ -466,6 +466,8 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
466466
x_diag = pt.eye(1, 1) * x
467467
y = pt.linalg.det(x_diag)
468468
f_rewritten = function([x], y, mode="FAST_RUN")
469+
pytensor.dprint(f_rewritten)
470+
469471
nodes = f_rewritten.maker.fgraph.apply_nodes
470472

471473
assert any(isinstance(node.op, Det) for node in nodes)
@@ -475,6 +477,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
475477
x_test_matrix = np.eye(1, 1) * x_test
476478
det_val = np.linalg.det(x_test_matrix)
477479
rewritten_val = f_rewritten(x_test)
480+
478481
assert_allclose(
479482
det_val,
480483
rewritten_val,

0 commit comments

Comments
 (0)