Skip to content

Added cholesky inverse operation to all the backends #21554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from keras.src.ops.core import while_loop as while_loop
from keras.src.ops.einops import rearrange as rearrange
from keras.src.ops.linalg import cholesky as cholesky
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
from keras.src.ops.linalg import det as det
from keras.src.ops.linalg import eig as eig
from keras.src.ops.linalg import eigh as eigh
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from keras.src.ops.linalg import cholesky as cholesky
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
from keras.src.ops.linalg import det as det
from keras.src.ops.linalg import eig as eig
from keras.src.ops.linalg import eigh as eigh
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from keras.src.ops.core import while_loop as while_loop
from keras.src.ops.einops import rearrange as rearrange
from keras.src.ops.linalg import cholesky as cholesky
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
from keras.src.ops.linalg import det as det
from keras.src.ops.linalg import eig as eig
from keras.src.ops.linalg import eigh as eigh
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from keras.src.ops.linalg import cholesky as cholesky
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
from keras.src.ops.linalg import det as det
from keras.src.ops.linalg import eig as eig
from keras.src.ops.linalg import eigh as eigh
Expand Down
14 changes: 12 additions & 2 deletions keras/src/backend/jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from keras.src.backend.jax.core import convert_to_tensor


def cholesky(a):
out = jnp.linalg.cholesky(a)
def cholesky(a, upper=False):
out = jnp.linalg.cholesky(a, upper=upper)
try:
# In eager mode, raise for nan to
# achieve behavior consistency with numpy
Expand All @@ -26,6 +26,16 @@ def cholesky(a):
return out


def cholesky_inverse(a, upper=False):
    """Invert a matrix given its triangular Cholesky factor.

    Args:
        a: Cholesky factor of shape `(..., M, M)`; treated as
            lower-triangular unless `upper` is True.
        upper: Whether `a` is the upper-triangular factor.

    Returns:
        Array with the same shape as `a` holding the inverse of the matrix
        that `a` factorizes.
    """
    # Give the identity the same batch shape as `a` so the right-hand side
    # cannot be mistaken for a stack of vectors by the triangular solver.
    identity = jnp.broadcast_to(jnp.eye(a.shape[-1], dtype=a.dtype), a.shape)
    inv_chol = solve_triangular(a, identity, lower=not upper)
    # Swap only the two matrix axes: `jnp.transpose` without explicit axes
    # reverses every axis, which scrambles batched inputs.
    inv_chol_t = jnp.swapaxes(inv_chol, -1, -2)
    # NOTE(review): a plain transpose is used (as before); complex inputs
    # would need a conjugate transpose -- confirm whether that is in scope.
    if upper:
        # A = U^T U  =>  A^-1 = U^-1 U^-T
        return jnp.matmul(inv_chol, inv_chol_t)
    # A = L L^T  =>  A^-1 = L^-T L^-1
    return jnp.matmul(inv_chol_t, inv_chol)


def det(a):
return jnp.linalg.det(a)

Expand Down
14 changes: 12 additions & 2 deletions keras/src/backend/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@
from keras.src.backend.numpy.core import convert_to_tensor


def cholesky(a):
return np.linalg.cholesky(a)
def cholesky(a, upper=False):
return np.linalg.cholesky(a, upper=upper)


def cholesky_inverse(a, upper=False):
    """Invert a matrix given its triangular Cholesky factor.

    Args:
        a: Cholesky factor of shape `(..., M, M)`; treated as
            lower-triangular unless `upper` is True.
        upper: Whether `a` is the upper-triangular factor.

    Returns:
        Array with the same shape as `a` holding the inverse of the matrix
        that `a` factorizes.
    """
    # Give the identity the same batch shape as `a` so the right-hand side
    # cannot be mistaken for a stack of vectors by the triangular solver.
    # `.copy()` turns the read-only broadcast view into a writable array.
    identity = np.broadcast_to(
        np.eye(a.shape[-1], dtype=a.dtype), a.shape
    ).copy()
    inv_chol = solve_triangular(a, identity, lower=not upper)
    # Swap only the two matrix axes: `.T` reverses every axis, which
    # scrambles batched inputs.
    inv_chol_t = np.swapaxes(inv_chol, -1, -2)
    # NOTE(review): a plain transpose is used (as before); complex inputs
    # would need a conjugate transpose -- confirm whether that is in scope.
    if upper:
        # A = U^T U  =>  A^-1 = U^-1 U^-T
        return np.matmul(inv_chol, inv_chol_t)
    # A = L L^T  =>  A^-1 = L^-T L^-1
    return np.matmul(inv_chol_t, inv_chol)


def det(a):
Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,5 @@ TestMathErrors::test_istft_invalid_window_shape_2D_inputs
TestMathErrors::test_stft_invalid_input_type
TestMathErrors::test_stft_invalid_window
TestMathErrors::test_stft_invalid_window_shape
LinalgOpsCorrectnessTest::test_cholesky
LinalgOpsCorrectnessTest::test_cholesky_inverse
10 changes: 8 additions & 2 deletions keras/src/backend/openvino/linalg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
def cholesky(a):
def cholesky(a, upper=False):
raise NotImplementedError(
"`cholesky` is not supported with openvino backend"
"`cholesky` is not supported with openvino backend."
)


def cholesky_inverse(a, upper=False):
    """Placeholder for the Cholesky-inverse op; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "`cholesky_inverse` is not supported with openvino backend."
    raise NotImplementedError(message)


Expand Down
17 changes: 15 additions & 2 deletions keras/src/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,23 @@
from keras.src.backend.tensorflow.core import convert_to_tensor


def cholesky(a):
def cholesky(a, upper=False):
out = tf.linalg.cholesky(a)
# tf.linalg.cholesky simply returns NaNs for non-positive definite matrices
return tf.debugging.check_numerics(out, "Cholesky")
out = tf.debugging.check_numerics(out, "Cholesky")
if upper:
return tf.linalg.adjoint(out)
return out


def cholesky_inverse(a, upper=False):
    """Invert a matrix given its triangular Cholesky factor.

    Args:
        a: Cholesky factor; treated as lower-triangular unless `upper`
            is True.
        upper: Whether `a` is the upper-triangular factor.

    Returns:
        Tensor holding the inverse of the matrix that `a` factorizes.
    """
    eye = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)
    # Solving `a @ X = I` yields the inverse of the triangular factor.
    factor_inv = tf.linalg.triangular_solve(a, eye, lower=not upper)
    if upper:
        # A = U^T U  =>  A^-1 = U^-1 U^-T
        return tf.matmul(factor_inv, factor_inv, transpose_b=True)
    # A = L L^T  =>  A^-1 = L^-T L^-1
    return tf.matmul(factor_inv, factor_inv, transpose_a=True)


def det(a):
Expand Down
8 changes: 6 additions & 2 deletions keras/src/backend/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
from keras.src.backend.torch.core import convert_to_tensor


def cholesky(x):
return torch.linalg.cholesky(x)
def cholesky(x, upper=False):
return torch.linalg.cholesky(x, upper=upper)


def cholesky_inverse(x, upper=False):
    """Delegate to `torch.cholesky_inverse`.

    Args:
        x: Triangular Cholesky factor tensor.
        upper: Forwarded unchanged as the `upper` flag.

    Returns:
        The tensor returned by `torch.cholesky_inverse`.
    """
    inverse = torch.cholesky_inverse(x, upper=upper)
    return inverse


def det(x):
Expand Down
68 changes: 59 additions & 9 deletions keras/src/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@


class Cholesky(Operation):
def __init__(self, upper=False, *, name=None):
super().__init__(name=name)
self.upper = upper

def call(self, x):
return _cholesky(x)
return _cholesky(x, self.upper)

def compute_output_spec(self, x):
_assert_2d(x)
Expand All @@ -17,32 +21,78 @@ def compute_output_spec(self, x):


@keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"])
def cholesky(x):
def cholesky(x, upper=False):
"""Computes the Cholesky decomposition of a positive semi-definite matrix.

Args:
x: Input tensor of shape `(..., M, M)`.
upper (bool): If True, returns the upper-triangular Cholesky factor.
If False (default), returns the lower-triangular Cholesky factor.

Returns:
A tensor of shape `(..., M, M)` representing the lower triangular
Cholesky factor of `x`.

A tensor of shape `(..., M, M)` representing the Cholesky factor of `x`.
"""
if any_symbolic_tensors((x,)):
return Cholesky().symbolic_call(x)
return _cholesky(x)
return Cholesky(upper=upper).symbolic_call(x)
return _cholesky(x, upper=upper)


def _cholesky(x):
def _cholesky(x, upper=False):
x = backend.convert_to_tensor(x)
_assert_2d(x)
_assert_square(x)
try:
return backend.linalg.cholesky(x)
return backend.linalg.cholesky(x, upper=upper)
except Exception as e:
raise ValueError(f"Cholesky decomposition failed: {e}")


class CholeskyInverse(Operation):
    """Operation wrapper around `_cholesky_inverse`."""

    def __init__(self, upper=False, *, name=None):
        super().__init__(name=name)
        # Forwarded to `_cholesky_inverse` on every call.
        self.upper = upper

    def call(self, x):
        """Run the computation on `x` using the stored `upper` flag."""
        return _cholesky_inverse(x, upper=self.upper)

    def compute_output_spec(self, x):
        """Validate `x` and describe an output with its shape and dtype."""
        _assert_2d(x)
        _assert_square(x)
        out_shape, out_dtype = x.shape, x.dtype
        return KerasTensor(out_shape, out_dtype)


@keras_export(
    ["keras.ops.cholesky_inverse", "keras.ops.linalg.cholesky_inverse"]
)
def cholesky_inverse(x, upper=False):
    """Computes the inverse of a matrix from its Cholesky factor.

    Given the triangular Cholesky factor `x` of a symmetric
    positive-definite matrix `A`, returns the inverse of `A`. Note that `x`
    is the factor itself (for example the output of `keras.ops.cholesky`),
    not the matrix `A`.

    Args:
        x: Cholesky factor of shape `(..., M, M)`.
        upper (bool): If True, `x` is treated as the upper-triangular factor
            `U` with `A = U^T U`. If False (default), `x` is treated as the
            lower-triangular factor `L` with `A = L L^T`.

    Returns:
        A tensor of shape `(..., M, M)` representing the inverse of `A`.

    Raises:
        ValueError: If `x` fails the rank or square-shape checks, or if the
            backend computation raises.
    """
    if any_symbolic_tensors((x,)):
        return CholeskyInverse(upper=upper).symbolic_call(x)
    return _cholesky_inverse(x, upper=upper)


def _cholesky_inverse(x, upper=False):
    """Eager implementation behind `cholesky_inverse`.

    Converts `x` to a backend tensor, runs the rank and square-shape checks,
    then dispatches to the backend's `cholesky_inverse`.

    Args:
        x: Cholesky factor to invert from.
        upper: Forwarded to the backend as the `upper` flag.

    Returns:
        The backend result.

    Raises:
        ValueError: If the backend computation raises; the original
            exception is attached as the cause.
    """
    x = backend.convert_to_tensor(x)
    _assert_2d(x)
    _assert_square(x)
    try:
        return backend.linalg.cholesky_inverse(x, upper=upper)
    except Exception as e:
        # Chain explicitly so the backend traceback is preserved as the
        # direct cause rather than as incidental context.
        raise ValueError(f"Cholesky inverse failed: {e}") from e


class Det(Operation):
def call(self, x):
return _det(x)
Expand Down
72 changes: 67 additions & 5 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
    """Symbolic shape inference with a dynamic batch dimension."""
    # A batch of square matrices keeps its symbolic shape.
    square = KerasTensor([None, 20, 20])
    self.assertEqual(linalg.cholesky_inverse(square).shape, (None, 20, 20))

    # An unknown matrix dimension, then a non-square matrix, are rejected.
    for bad_shape in ([None, None, 20], [None, 20, 15]):
        invalid = KerasTensor(bad_shape)
        with self.assertRaises(ValueError):
            linalg.cholesky_inverse(invalid)

def test_det(self):
x = KerasTensor([None, 20, 20])
out = linalg.det(x)
Expand Down Expand Up @@ -196,6 +209,15 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
    """Symbolic shape inference with fully static shapes."""
    square = KerasTensor([4, 3, 3])
    result = linalg.cholesky_inverse(square)
    self.assertEqual(result.shape, (4, 3, 3))

    # Non-square trailing dimensions must be rejected.
    non_square = KerasTensor([10, 20, 15])
    with self.assertRaises(ValueError):
        linalg.cholesky_inverse(non_square)

def test_det(self):
x = KerasTensor([4, 3, 3])
out = linalg.det(x)
Expand Down Expand Up @@ -331,12 +353,52 @@ def test_svd(self):

class LinalgOpsCorrectnessTest(testing.TestCase):
def test_cholesky(self):
x = np.random.rand(4, 3, 3).astype("float32")
x_non_psd = np.random.rand(4, 3, 3).astype("float32")
with self.assertRaises(ValueError):
linalg.cholesky(x)
x_psd = x @ x.transpose((0, 2, 1)) + 1e-5 * np.eye(3)
out = linalg.cholesky(x_psd)
self.assertAllClose(out, np.linalg.cholesky(x_psd), atol=1e-4)
linalg.cholesky(x_non_psd)

x = np.random.rand(4, 3, 3).astype("float32")
x_psd = np.matmul(x, x.transpose((0, 2, 1))) + 1e-5 * np.eye(
3, dtype="float32"
)

l_out = linalg.cholesky(x_psd, upper=False)
l_expected = np.linalg.cholesky(x_psd)
self.assertAllClose(l_out, l_expected, atol=1e-4)

u_out = linalg.cholesky(x_psd, upper=True)
u_expected = l_expected.transpose((0, 2, 1))
self.assertAllClose(u_out, u_expected, atol=1e-4)

@parameterized.named_parameters(
    {"testcase_name": "lower", "upper": False},
    {"testcase_name": "upper", "upper": True},
)
def test_cholesky_inverse(self, upper):
    """Check the inverse recovered from either triangular factor."""
    # Fixed symmetric positive-definite matrix and its reference inverse.
    spd = np.array(
        [
            [4.0, 12.0, -16.0],
            [12.0, 37.0, -43.0],
            [-16.0, -43.0, 98.0],
        ],
        dtype="float32",
    )
    expected_inverse = np.array(
        [
            [49.36111, -13.555555, 2.111111],
            [-13.555555, 3.777778, -0.555556],
            [2.111111, -0.555556, 0.111111],
        ],
        dtype="float32",
    )

    # The op takes the factor, not the matrix itself. The `upper` keyword
    # is passed to NumPy only when the upper factor is requested.
    factor = (
        np.linalg.cholesky(spd, upper=True)
        if upper
        else np.linalg.cholesky(spd)
    )

    output_inverse = linalg.cholesky_inverse(factor, upper=upper)
    self.assertAllClose(output_inverse, expected_inverse, atol=1e-5)

def test_det(self):
x = np.random.rand(4, 3, 3)
Expand Down