Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions keras/src/backend/jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ def cholesky(a):
pass
return out

def cholesky_inverse(a):
L = cholesky(a)
identity = jnp.eye(a.shape[0], dtype=a.dtype)
L_inv = solve_triangular(L, identity, lower=True)
out = L_inv.T @ L_inv
try:
if jnp.any(jnp.isnan(out)):
raise ValueError(
"Cholesky inverse failed. The input might not be a valid "
"positive definite matrix."
)
except jax.errors.TracerBoolConversionError:
pass
return out

def det(a):
return jnp.linalg.det(a)
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
def cholesky(a):
return np.linalg.cholesky(a)

def cholesky_inverse(a):
L = np.linalg.cholesky(a)
identity = np.eye(a.shape[0], dtype=a.dtype)
L_inv = np.linalg.solve(L, identity)
a_inv = L_inv.T @ L_inv
return a_inv

def det(a):
return np.linalg.det(a)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/openvino/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ def cholesky(a):
"`cholesky` is not supported with openvino backend"
)

def cholesky_inverse(a):
raise NotImplementedError(
"`cholesky inverse` is not supported with openvino backend"
)


def det(a):
raise NotImplementedError("`det` is not supported with openvino backend")
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ def cholesky(a):
# tf.linalg.cholesky simply returns NaNs for non-positive definite matrices
return tf.debugging.check_numerics(out, "Cholesky")

def cholesky_inverse(a):
L = cholesky(a)
identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)
L_inv = solve_triangular(L, identity, lower=True)
a_inv = tf.matmul(L_inv, L_inv, transpose_a=True)
return a_inv

def det(a):
return tf.linalg.det(a)
Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
def cholesky(x):
return torch.linalg.cholesky(x)

def cholesky_inverse(x):
return torch.cholesky_inverse(x)

def det(x):
return torch.det(x)
Expand Down
27 changes: 27 additions & 0 deletions keras/src/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,33 @@ def _cholesky(x):
raise ValueError(f"Cholesky decomposition failed: {e}")


class CholeskyInverse(Operation):
def call(self, x):
return _cholesky_inverse(x)

def compute_output_spec(self, x):
_assert_2d(x)
_assert_square(x)
return KerasTensor(x.shape, x.dtype)


@keras_export(["keras.ops.cholesky_inverse", "keras.ops.linalg.cholesky_inverse"])
def cholesky_inverse(x):
if any_symbolic_tensors((x,)):
return CholeskyInverse().symbolic_call(x)
return _cholesky_inverse(x)


def _cholesky_inverse(x):
x = backend.convert_to_tensor(x)
_assert_2d(x)
_assert_square(x)
try:
return backend.linalg.cholesky_inverse(x)
except Exception as e:
raise ValueError(f"Cholesky inverse failed: {e}")


class Det(Operation):
def call(self, x):
return _det(x)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
x = KerasTensor([None, 20, 20])
out = linalg.cholesky_inverse(x)
self.assertEqual(out.shape, (None, 20, 20))

x = KerasTensor([None, None, 20])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

x = KerasTensor([None, 20, 15])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

def test_det(self):
x = KerasTensor([None, 20, 20])
out = linalg.det(x)
Expand Down
Loading