Skip to content

Commit 1a89309

Browse files
Use rewrite_mode defined in test_math.py for testing
1 parent 14ee8e2 commit 1a89309

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/tensor/rewriting/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4765,15 +4765,15 @@ def test_local_block_diag_dot_to_dot_block_diag(left_multiply):
47654765
else:
47664766
out = d @ x
47674767

4768-
fn = pytensor.function([a, b, c, d], out)
4768+
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
47694769
assert not any(
47704770
isinstance(node.op, BlockDiagonal) for node in fn.maker.fgraph.toposort()
47714771
)
47724772

47734773
fn_expected = pytensor.function(
47744774
[a, b, c, d],
47754775
out,
4776-
mode=get_default_mode().excluding("local_block_diag_dot_to_dot_block_diag"),
4776+
mode=rewrite_mode.excluding("local_block_diag_dot_to_dot_block_diag"),
47774777
)
47784778

47794779
rng = np.random.default_rng()

0 commit comments

Comments
 (0)