diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py
index ff1c74cdec..9ea8db37fc 100644
--- a/pytensor/tensor/_linalg/solve/rewriting.py
+++ b/pytensor/tensor/_linalg/solve/rewriting.py
@@ -14,16 +14,22 @@
 from pytensor.tensor.variable import TensorVariable
 
 
-def decompose_A(A, assume_a):
+def decompose_A(A, assume_a, check_finite):
     if assume_a == "gen":
-        return lu_factor(A, check_finite=False)
+        return lu_factor(A, check_finite=check_finite)
     else:
         raise NotImplementedError
 
 
-def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
-    if assume_a == "gen":
-        return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
+def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
+    if core_solve_op.assume_a == "gen":
+        return lu_solve(
+            A_decomp,
+            b,
+            trans=transposed,
+            b_ndim=core_solve_op.b_ndim,
+            check_finite=core_solve_op.check_finite,
+        )
     else:
         raise NotImplementedError
 
@@ -102,14 +108,19 @@ def find_solve_clients(var, assume_a):
     ):
         return None
 
-    A_decomp = decompose_A(A, assume_a=assume_a)
+    # If any Op had check_finite=True, we also do it for the LU decomposition
+    check_finite_decomp = False
+    for client, _ in A_solve_clients_and_transpose:
+        if client.op.core_op.check_finite:
+            check_finite_decomp = True
+            break
+    A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)
 
     replacements = {}
     for client, transposed in A_solve_clients_and_transpose:
         _, b = client.inputs
-        b_ndim = client.op.core_op.b_ndim
         new_x = solve_lu_decomposed_system(
-            A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
+            A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
         )
         [old_x] = client.outputs
         new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index 0474aad77b..5ae92006e2 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -793,7 +793,7 @@ def tensor(
         try:
             # Help catching errors with the new tensor API
             # Many single letter strings are valid sctypes
-            if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type):
+            if str(name) == "floatX" or (len(str(name)) > 2 and np.dtype(name).type):
                 raise ValueError(
                     f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
                     "This name looks like a dtype, which you should pass as a keyword argument only."
diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py
index 6f04fac5fb..32683029f0 100644
--- a/tests/tensor/linalg/test_rewriting.py
+++ b/tests/tensor/linalg/test_rewriting.py
@@ -161,3 +161,36 @@ def test_lu_decomposition_reused_scan(transposed):
     resx1 = fn_opt(A_test, x0_test)
     rtol = 1e-7 if config.floatX == "float64" else 1e-6
     np.testing.assert_allclose(resx0, resx1, rtol=rtol)
+
+
+def test_lu_decomposition_reused_preserves_check_finite():
+    # Check that the LU decomposition rewrite preserves the check_finite flag
+    rewrite_name = reuse_lu_decomposition_multiple_solves.__name__
+
+    A = tensor("A", shape=(2, 2))
+    b1 = tensor("b1", shape=(2,))
+    b2 = tensor("b2", shape=(2,))
+
+    x1 = solve(A, b1, assume_a="gen", check_finite=True)
+    x2 = solve(A, b2, assume_a="gen", check_finite=False)
+    fn_opt = function(
+        [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
+    )
+    opt_nodes = fn_opt.maker.fgraph.apply_nodes
+    assert count_vanilla_solve_nodes(opt_nodes) == 0
+    assert count_lu_decom_nodes(opt_nodes) == 1
+    assert count_lu_solve_nodes(opt_nodes) == 2
+
+    # We should get an error if A or b1 is non finite
+    A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
+    b1_valid = np.array([1, 1], dtype=b1.type.dtype)
+    b2_valid = np.array([1, 1], dtype=b2.type.dtype)
+
+    assert fn_opt(A_valid, b1_valid, b2_valid)  # Fine
+    assert fn_opt(
+        A_valid, b1_valid, b2_valid * np.nan
+    )  # Should not raise (also fine on most LAPACK implementations?)
+    with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
+        assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
+    with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
+        assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)