Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion keras/src/backend/jax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from keras.src.backend.jax.core import convert_to_tensor


def cholesky(a):
def cholesky(a, upper=False):
out = jnp.linalg.cholesky(a)
try:
# In eager mode, raise for nan to
Expand All @@ -23,9 +23,22 @@ def cholesky(a):
except jax.errors.TracerBoolConversionError:
# Cannot raise for nan in tracing mode
pass
if upper:
return jnp.swapaxes(out, -2, -1)
return out


def cholesky_inverse(a, upper=False):
identity = jnp.eye(a.shape[-1], dtype=a.dtype)
if upper:
u_inv = solve_triangular(a, identity, lower=False)
a_inv = jnp.matmul(u_inv, jnp.transpose(u_inv))
else:
l_inv = solve_triangular(a, identity, lower=True)
a_inv = jnp.matmul(jnp.transpose(l_inv), l_inv)
return a_inv


def det(a):
return jnp.linalg.det(a)

Expand Down
15 changes: 14 additions & 1 deletion keras/src/backend/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,23 @@
from keras.src.backend.numpy.core import convert_to_tensor


def cholesky(a):
def cholesky(a, upper=False):
if upper:
return np.linalg.cholesky(a, upper=True)
return np.linalg.cholesky(a)


def cholesky_inverse(a, upper=False):
identity = np.eye(a.shape[-1], dtype=a.dtype)
if upper:
u_inv = solve_triangular(a, identity, lower=False)
a_inv = np.matmul(u_inv, u_inv.T)
else:
l_inv = solve_triangular(a, identity, lower=True)
a_inv = np.matmul(l_inv.T, l_inv)
return a_inv


def det(a):
return np.linalg.det(a)

Expand Down
8 changes: 7 additions & 1 deletion keras/src/backend/openvino/linalg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
def cholesky(a):
def cholesky(a, upper=False):
raise NotImplementedError(
"`cholesky` is not supported with openvino backend"
)


def cholesky_inverse(a, upper=False):
raise NotImplementedError(
"`cholesky_inverse` is not supported with openvino backend"
)


def det(a):
raise NotImplementedError("`det` is not supported with openvino backend")

Expand Down
18 changes: 16 additions & 2 deletions keras/src/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,24 @@
from keras.src.backend.tensorflow.core import convert_to_tensor


def cholesky(a):
def cholesky(a, upper=True):
out = tf.linalg.cholesky(a)
# tf.linalg.cholesky simply returns NaNs for non-positive definite matrices
return tf.debugging.check_numerics(out, "Cholesky")
out = tf.debugging.check_numerics(out, "Cholesky")
if upper:
return tf.linalg.matrix_transpose(out)
return out


def cholesky_inverse(a, upper=True):
identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)
if upper:
u_inv = tf.linalg.triangular_solve(a, identity, lower=False)
a_inv = tf.matmul(u_inv, u_inv, transpose_b=True)
else:
l_inv = tf.linalg.triangular_solve(a, identity, lower=True)
a_inv = tf.matmul(l_inv, l_inv, transpose_a=True)
return a_inv


def det(a):
Expand Down
10 changes: 9 additions & 1 deletion keras/src/backend/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@
from keras.src.backend.torch.core import convert_to_tensor


def cholesky(x):
def cholesky(x, upper=False):
if upper:
return torch.linalg.cholesky(x, upper=True)
return torch.linalg.cholesky(x)


def cholesky_inverse(x, upper=False):
if upper:
return torch.cholesky_inverse(x, upper=True)
return torch.cholesky_inverse(x)


def det(x):
return torch.det(x)

Expand Down
62 changes: 51 additions & 11 deletions keras/src/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,82 @@


class Cholesky(Operation):
def call(self, x):
return _cholesky(x)
def call(self, x, upper=False):
return _cholesky(x, upper)

def compute_output_spec(self, x):
def compute_output_spec(self, x, upper=False):
_assert_2d(x)
_assert_square(x)
return KerasTensor(x.shape, x.dtype)


@keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"])
def cholesky(x):
def cholesky(x, upper=False):
"""Computes the Cholesky decomposition of a positive semi-definite matrix.

Args:
x: Input tensor of shape `(..., M, M)`.
upper (bool): If True, returns the upper-triangular Cholesky factor.
If False (default), returns the lower-triangular Cholesky factor.

Returns:
A tensor of shape `(..., M, M)` representing the lower triangular
Cholesky factor of `x`.

A tensor of shape `(..., M, M)` representing the Cholesky factor of `x`.
"""
if any_symbolic_tensors((x,)):
return Cholesky().symbolic_call(x)
return _cholesky(x)
return Cholesky().symbolic_call(x, upper=upper)
return _cholesky(x, upper=upper)


def _cholesky(x):
def _cholesky(x, upper=False):
x = backend.convert_to_tensor(x)
_assert_2d(x)
_assert_square(x)
try:
return backend.linalg.cholesky(x)
return backend.linalg.cholesky(x, upper=upper)
except Exception as e:
raise ValueError(f"Cholesky decomposition failed: {e}")


class CholeskyInverse(Operation):
def call(self, x, upper=False):
return _cholesky_inverse(x, upper)

def compute_output_spec(self, x, upper=False):
_assert_2d(x)
_assert_square(x)
return KerasTensor(x.shape, x.dtype)


@keras_export(["keras.ops.cholesky_inverse", "keras.ops.linalg.cholesky_inverse"])
def cholesky_inverse(x, upper=False):
"""Computes the inverse of a symmetric positive-definite matrix.

Args:
x: Input tensor of shape `(..., M, M)`.
upper (bool): Determines whether to use the upper- or lower-triangular
factor for the internal computation. Defaults to False.

Returns:
A tensor of shape `(..., M, M)` representing the inverse of `x`.

Raises:
ValueError: If `x` is not a symmetric positive-definite matrix.
"""
if any_symbolic_tensors((x,)):
return CholeskyInverse().symbolic_call(x, upper=upper)
return _cholesky_inverse(x, upper=upper)


def _cholesky_inverse(x, upper=False):
x = backend.convert_to_tensor(x)
_assert_2d(x)
_assert_square(x)
try:
return backend.linalg.cholesky_inverse(x, upper=upper)
except Exception as e:
raise ValueError(f"Cholesky inverse failed: {e}")


class Det(Operation):
def call(self, x):
return _det(x)
Expand Down
59 changes: 54 additions & 5 deletions keras/src/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
x = KerasTensor([None, 20, 20])
out = linalg.cholesky_inverse(x)
self.assertEqual(out.shape, (None, 20, 20))

x = KerasTensor([None, None, 20])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

x = KerasTensor([None, 20, 15])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

def test_det(self):
x = KerasTensor([None, 20, 20])
out = linalg.det(x)
Expand Down Expand Up @@ -196,6 +209,15 @@ def test_cholesky(self):
with self.assertRaises(ValueError):
linalg.cholesky(x)

def test_cholesky_inverse(self):
x = KerasTensor([4, 3, 3])
out = linalg.cholesky_inverse(x)
self.assertEqual(out.shape, (4, 3, 3))

x = KerasTensor([10, 20, 15])
with self.assertRaises(ValueError):
linalg.cholesky_inverse(x)

def test_det(self):
x = KerasTensor([4, 3, 3])
out = linalg.det(x)
Expand Down Expand Up @@ -331,12 +353,39 @@ def test_svd(self):

class LinalgOpsCorrectnessTest(testing.TestCase):
def test_cholesky(self):
x = np.random.rand(4, 3, 3).astype("float32")
x_non_psd = np.random.rand(4, 3, 3).astype("float32")
with self.assertRaises(ValueError):
linalg.cholesky(x)
x_psd = x @ x.transpose((0, 2, 1)) + 1e-5 * np.eye(3)
out = linalg.cholesky(x_psd)
self.assertAllClose(out, np.linalg.cholesky(x_psd), atol=1e-4)
linalg.cholesky(x_non_psd)
x = np.random.rand(4, 3, 3).astype("float32")
x_psd = np.matmul(x, x.transpose((0, 2, 1))) + 1e-5 * np.eye(3)

l_out = linalg.cholesky(x_psd, upper=False)
l_expected = np.linalg.cholesky(x_psd)
self.assertAllClose(l_out, l_expected, atol=1e-4)

u_out = linalg.cholesky(x_psd, upper=True)
u_expected = l_expected.transpose((0, 2, 1))
self.assertAllClose(u_out, u_expected, atol=1e-4)


def test_cholesky_inverse(self):
x_np = np.random.rand(3, 3).astype("float32")
x_psd_np = np.matmul(x_np, x_np.T) + 1e-4 * np.eye(3, dtype="float32")
identity = np.eye(3, dtype="float32")

l_factor_np = np.linalg.cholesky(x_psd_np)
x_inv_from_l = linalg.cholesky_inverse(l_factor_np, upper=False)
reconstructed_from_l = ops.matmul(x_psd_np, x_inv_from_l)
self.assertAllClose(
reconstructed_from_l, identity, atol=1e-4, rtol=1e-4
)

u_factor_np = l_factor_np.T
x_inv_from_u = linalg.cholesky_inverse(u_factor_np, upper=True)
reconstructed_from_u = ops.matmul(x_psd_np, x_inv_from_u)
self.assertAllClose(
reconstructed_from_u, identity, atol=1e-4, rtol=1e-4
)

def test_det(self):
x = np.random.rand(4, 3, 3)
Expand Down
Loading