@@ -362,6 +362,8 @@ def test_invalid_batched_a(self):
362
362
ids = ["block_diag" , "kron" ],
363
363
)
364
364
def test_local_lift_through_linalg (constructor , f_op , f , g_op , g ):
365
+ rng = np .random .default_rng (sum (map (ord , "lift_through_linalg" )))
366
+
365
367
if pytensor .config .floatX .endswith ("32" ):
366
368
pytest .skip ("Test is flaky at half precision" )
367
369
@@ -387,9 +389,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
387
389
assert len (f_ops ) == 2
388
390
assert len (g_ops ) == 1
389
391
390
- test_vals = [
391
- np .random .normal (size = (3 ,) * A .ndim ).astype (config .floatX ) for _ in range (2 )
392
- ]
392
+ test_vals = [rng .normal (size = (3 ,) * A .ndim ).astype (config .floatX ) for _ in range (2 )]
393
393
test_vals = [x @ np .swapaxes (x , - 1 , - 2 ) for x in test_vals ]
394
394
395
395
np .testing .assert_allclose (f1 (* test_vals ), f2 (* test_vals ), atol = 1e-8 )
@@ -466,6 +466,8 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
466
466
x_diag = pt .eye (1 , 1 ) * x
467
467
y = pt .linalg .det (x_diag )
468
468
f_rewritten = function ([x ], y , mode = "FAST_RUN" )
469
+ pytensor .dprint (f_rewritten )
470
+
469
471
nodes = f_rewritten .maker .fgraph .apply_nodes
470
472
471
473
assert any (isinstance (node .op , Det ) for node in nodes )
@@ -475,6 +477,7 @@ def test_dont_apply_det_diag_rewrite_for_1_1():
475
477
x_test_matrix = np .eye (1 , 1 ) * x_test
476
478
det_val = np .linalg .det (x_test_matrix )
477
479
rewritten_val = f_rewritten (x_test )
480
+
478
481
assert_allclose (
479
482
det_val ,
480
483
rewritten_val ,
0 commit comments