Jacobians of chol solve and lu solve are different #14926
The following code compares the Jacobians, with respect to the coefficient matrix, of two different ways of solving a system of linear equations. I get different results depending on the method I use to factorize the coefficient matrix. Is that normal?

```python
from jax import jacfwd, random
import jax.numpy as jnp
import jax.scipy as jsp

def solve_chol(A, b):
    L = jsp.linalg.cho_factor(A)
    return jsp.linalg.cho_solve(L, b)

def solve_lu(A, b):
    LU = jsp.linalg.lu_factor(A)
    return jsp.linalg.lu_solve(LU, b)

seed = 13
key = random.PRNGKey(seed=seed)
key, key_A, key_b = random.split(key, 3)
p = 3
A = random.normal(key=key_A, shape=(p, p))
A = A.T @ A + jnp.eye(p)  # symmetric positive definite
b = random.normal(key=key_b, shape=(p,))

# These two tests pass
assert jnp.allclose(solve_chol(A, b), solve_lu(A, b))
assert jnp.allclose(solve_chol(A, b), jnp.linalg.solve(A, b))

# These two tests fail
assert jnp.allclose(jacfwd(solve_chol, 0)(A, b), jacfwd(solve_lu, 0)(A, b))
assert jnp.allclose(jacfwd(solve_chol, 0)(A, b), jacfwd(jnp.linalg.solve, 0)(A, b))

# This test passes
assert jnp.allclose(jacfwd(solve_lu, 0)(A, b), jacfwd(jsp.linalg.solve, 0)(A, b))
```
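As a point of reference, differentiating `x = A^{-1} b` gives `dx = -A^{-1} (dA) x`, so for an unconstrained `A` the Jacobian is `J[i, j, k] = dx_i / dA_jk = -A^{-1}[i, j] * x[k]`. The sketch below checks that the LU-based Jacobian (and the one from `jnp.linalg.solve`) matches this closed form, i.e. that it treats all `p * p` entries of `A` as independent inputs. The Cholesky-based one does not, which is why two of the checks above fail. The sketch assumes 64-bit mode purely so that the default `allclose` tolerances are comfortably met.

```python
import jax

jax.config.update("jax_enable_x64", True)  # assumption: float64 for tight tolerances

import jax.numpy as jnp
from jax import jacfwd, random
from jax.scipy.linalg import lu_factor, lu_solve


def solve_lu(A, b):
    return lu_solve(lu_factor(A), b)


_, key_A, key_b = random.split(random.PRNGKey(13), 3)
p = 3
A = random.normal(key_A, (p, p))
A = A.T @ A + jnp.eye(p)  # symmetric positive definite, as in the question
b = random.normal(key_b, (p,))

x = solve_lu(A, b)
A_inv = jnp.linalg.inv(A)

# Closed form for an unconstrained A: J[i, j, k] = dx_i / dA_jk = -A_inv[i, j] * x[k]
J_exact = -jnp.einsum("ij,k->ijk", A_inv, x)

J_lu = jacfwd(solve_lu, 0)(A, b)          # shape (p, p, p)
J_np = jacfwd(jnp.linalg.solve, 0)(A, b)  # same convention as the LU route

print(jnp.allclose(J_lu, J_exact), jnp.allclose(J_np, J_exact))
```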
It's expected but surprising. It comes down to whether you think `cholesky` represents a function on symmetric square matrices, or on the upper triangles of any square matrices. JAX's default convention is to choose the former. See this comment on #10815.

One way to make them agree is to make the functions you're differentiating functions on symmetric matrices (via orthogonal projection onto that subspace):

```python
from jax import jacfwd, random
import jax.numpy as jnp
import jax.scipy as jsp

def solve_chol(A, b):
    A = (A + A.T) / 2.  # NEW: project onto symmetric matrices
    L = jsp.linalg.cho_factor(A)
    return jsp.linalg.cho_solve(L, b)

def solve_lu(A, b):
    A = (A + A.T) / 2.  # NEW: project onto symmetric matrices
    LU = jsp.linalg.lu_factor(A)
    return jsp.linalg.lu_solve(LU, b)

seed = 13
key = random.PRNGKey(seed=seed)
key, key_A, key_b = random.split(key, 3)
p = 3
A = random.normal(key=key_A, shape=(p, p))
A = A.T @ A + jnp.eye(p)
b = random.normal(key=key_b, shape=(p,))

print(jacfwd(solve_chol, 0)(A, b))
print(jacfwd(solve_lu, 0)(A, b))
```

They could probably be made to agree as functions on just one particular triangle of the input, though that'd take a little more fiddling. WDYT?
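A minimal sketch of both options, assuming 64-bit mode only to keep tolerances tight. With the symmetric projection `S = (A + A.T) / 2`, a perturbation `dA` moves `x` by `-S^{-1} ((dA + dA.T) / 2) x`, so `J[i, j, k] = -(S^{-1}[i, j] * x[k] + S^{-1}[i, k] * x[j]) / 2`. For the one-triangle option, the helper `from_upper` is hypothetical: it is one possible parameterization that builds the symmetric matrix from the upper triangle only, not necessarily the "fiddling" meant above. With it, the Jacobian is zero strictly below the diagonal, and each strictly-upper entry picks up both symmetric contributions. The names `sym`, `from_upper`, `chol` and `lu` are illustrative helpers, not JAX APIs.

```python
import jax

jax.config.update("jax_enable_x64", True)  # assumption: float64 for tight tolerances

import jax.numpy as jnp
from jax import jacfwd, random
from jax.scipy.linalg import cho_factor, cho_solve, lu_factor, lu_solve


def sym(A):
    # Orthogonal projection onto symmetric matrices, as in the reply.
    return (A + A.T) / 2.0


def from_upper(A):
    # Hypothetical helper: symmetric matrix built from the upper triangle of A;
    # entries strictly below the diagonal are ignored.
    return jnp.triu(A) + jnp.triu(A, 1).T


def chol(S, b):
    return cho_solve(cho_factor(S), b)


def lu(S, b):
    return lu_solve(lu_factor(S), b)


_, key_A, key_b = random.split(random.PRNGKey(13), 3)
p = 3
A = random.normal(key_A, (p, p))
A = A.T @ A + jnp.eye(p)  # already symmetric, so sym(A) and from_upper(A) equal A
b = random.normal(key_b, (p,))

x = chol(A, b)
S_inv = jnp.linalg.inv(A)
G = -jnp.einsum("ij,k->ijk", S_inv, x)  # unconstrained derivative G[i, j, k]
H = G + G.transpose(0, 2, 1)            # G[i, j, k] + G[i, k, j]

# Option 1: functions on symmetric matrices.
J_sym_chol = jacfwd(lambda M: chol(sym(M), b))(A)
J_sym_lu = jacfwd(lambda M: lu(sym(M), b))(A)
J_sym_exact = H / 2

# Option 2: functions on the upper triangle only.
J_up_chol = jacfwd(lambda M: chol(from_upper(M), b))(A)
J_up_lu = jacfwd(lambda M: lu(from_upper(M), b))(A)
upper = jnp.triu(jnp.ones((p, p)))            # mask for j <= k
J_up_exact = (H - G * jnp.eye(p)) * upper     # diagonal counted once, lower part zero

print(jnp.allclose(J_sym_chol, J_sym_lu), jnp.allclose(J_up_chol, J_up_lu))
```

In both versions the tangent reaching the factorization is symmetric, so the Cholesky and LU routes compute the same derivative. Which convention to pick depends on what a perturbation of `A` is supposed to mean: the projection splits sensitivity evenly between `A[j, k]` and `A[k, j]`, while the upper-triangle version assigns it all to the entry that is actually read.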