Skip to content

Commit dba3d18

Browse files
committed
add check for k = 0
1 parent 8c16f65 commit dba3d18

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,12 @@ def _find_solve_with_eye(node):
552552
# Check whether second input to solve is Eye
553553
solve_inputs = node.inputs
554554
potential_eye_input = solve_inputs[1].owner
555+
555556
if not (potential_eye_input and isinstance(potential_eye_input.op, Eye)):
556557
return False
557558

559+
if getattr(potential_eye_input.inputs[-1], "data", -1).item() != 0:
560+
return False
558561
return True
559562

560563

tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,5 +564,5 @@ def get_pt_function(x, op_name):
564564
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
565565

566566
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
567-
x_testing = np.random.rand(10, 10)
567+
x_testing = np.random.rand(10, 10).astype(config.floatX)
568568
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)

0 commit comments

Comments
 (0)