Skip to content

Added cholesky inverse operation to all the backends #21554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions keras/src/backend/jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ def cholesky(a):
pass
return out

def cholesky_inverse(a):
identity = jnp.eye(a.shape[-1], dtype=a.dtype)
a_inv = solve_triangular(a, identity, lower=True)
out = jnp.matmul(jnp.transpose(a_inv), a_inv)
try:
if jnp.any(jnp.isnan(out)):
raise ValueError(
"Cholesky inverse failed. The input might not be a valid "
"positive definite matrix."
)
except jax.errors.TracerBoolConversionError:
pass
return out

def det(a):
return jnp.linalg.det(a)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
def cholesky(a):
return np.linalg.cholesky(a)

def cholesky_inverse(a):
identity = np.eye(a.shape[-1], dtype=a.dtype)
a_inv = solve_triangular(a, identity, lower=True)
out = solve_triangular(np.transpose(a), a_inv, lower=False)
return out

def det(a):
return np.linalg.det(a)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/openvino/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ def cholesky(a):
"`cholesky` is not supported with openvino backend"
)

def cholesky_inverse(a):
raise NotImplementedError(
"`Cholesky inverse` is not supported with the OpenVINO backend"
)


def det(a):
raise NotImplementedError("`det` is not supported with openvino backend")
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def cholesky(a):
# tf.linalg.cholesky simply returns NaNs for non-positive definite matrices
return tf.debugging.check_numerics(out, "Cholesky")

def cholesky_inverse(a):
identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)
a_inv = solve_triangular(a, identity, lower=True)
out = tf.matmul(a_inv, a_inv, transpose_a=True)
return out

def det(a):
return tf.linalg.det(a)
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def cholesky(x):
return torch.linalg.cholesky(x)


def cholesky_inverse(x):
return torch.cholesky_inverse(x)


def det(x):
return torch.det(x)

Expand Down
42 changes: 42 additions & 0 deletions keras/src/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,48 @@ def _cholesky(x):
raise ValueError(f"Cholesky decomposition failed: {e}")


class CholeskyInverse(Operation):
def call(self, x):
return _cholesky_inverse(x)

def compute_output_spec(self, x):
_assert_2d(x)
_assert_square(x)
return KerasTensor(x.shape, x.dtype)


def cholesky_inverse(x):
"""Computes the inverse of a symmetric positive-definite matrix using the
Cholesky decomposition.

This function is more efficient and numerically stable than `keras.ops.inv`
for symmetric positive-definite matrices.

Args:
x: Input tensor of shape `(..., M, M)`. The matrix must be symmetric and
positive-definite.

Returns:
A tensor of shape `(..., M, M)` representing the inverse of `x`.

Raises:
ValueError: If `x` is not a symmetric positive-definite matrix.
"""
if any_symbolic_tensors((x,)):
return CholeskyInverse().symbolic_call(x)
return _cholesky_inverse(x)


def _cholesky_inverse(x):
x = backend.convert_to_tensor(x)
_assert_2d(x)
_assert_square(x)
try:
return backend.linalg.cholesky_inverse(x)
except Exception as e:
raise ValueError(f"Cholesky inverse failed: {e}")


class Det(Operation):
def call(self, x):
return _det(x)
Expand Down
32 changes: 32 additions & 0 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
x = KerasTensor([None, 20, 20])
out = linalg.cholesky_inverse(x)
self.assertEqual(out.shape, (None, 20, 20))

x = KerasTensor([None, None, 20])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

x = KerasTensor([None, 20, 15])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

def test_det(self):
x = KerasTensor([None, 20, 20])
out = linalg.det(x)
Expand Down Expand Up @@ -196,6 +209,15 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
x = KerasTensor([4, 3, 3])
out = linalg.cholesky_inverse(x)
self.assertEqual(out.shape, (4, 3, 3))

x = KerasTensor([10, 20, 15])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

def test_det(self):
x = KerasTensor([4, 3, 3])
out = linalg.det(x)
Expand Down Expand Up @@ -338,6 +360,16 @@ def test_cholesky(self):
out = linalg.cholesky(x_psd)
self.assertAllClose(out, np.linalg.cholesky(x_psd), atol=1e-4)

def test_cholesky_inverse(self):
x_np = np.random.rand(3, 3).astype("float32")
x_psd_np = x_np @ x_np.T + 1e-4 * np.eye(3, dtype="float32")
x = linalg.cholesky(x_psd_np)
result_from_op = linalg.cholesky_inverse(x)
ground_truth_np = np.linalg.inv(x_psd_np)
self.assertAllClose(
result_from_op, ground_truth_np, atol=1e-5, rtol=1e-5
)

def test_det(self):
x = np.random.rand(4, 3, 3)
out = linalg.det(x)
Expand Down
Loading