diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 0d4037717cc2..5ff4561e9d5e 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -32,6 +32,7 @@ from keras.src.ops.core import while_loop as while_loop from keras.src.ops.einops import rearrange as rearrange from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse from keras.src.ops.linalg import det as det from keras.src.ops.linalg import eig as eig from keras.src.ops.linalg import eigh as eigh diff --git a/keras/api/_tf_keras/keras/ops/linalg/__init__.py b/keras/api/_tf_keras/keras/ops/linalg/__init__.py index bc091ea766a4..0c96c3bbb8dc 100644 --- a/keras/api/_tf_keras/keras/ops/linalg/__init__.py +++ b/keras/api/_tf_keras/keras/ops/linalg/__init__.py @@ -5,6 +5,7 @@ """ from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse from keras.src.ops.linalg import det as det from keras.src.ops.linalg import eig as eig from keras.src.ops.linalg import eigh as eigh diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 0d4037717cc2..5ff4561e9d5e 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -32,6 +32,7 @@ from keras.src.ops.core import while_loop as while_loop from keras.src.ops.einops import rearrange as rearrange from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse from keras.src.ops.linalg import det as det from keras.src.ops.linalg import eig as eig from keras.src.ops.linalg import eigh as eigh diff --git a/keras/api/ops/linalg/__init__.py b/keras/api/ops/linalg/__init__.py index bc091ea766a4..0c96c3bbb8dc 100644 --- a/keras/api/ops/linalg/__init__.py +++ b/keras/api/ops/linalg/__init__.py @@ -5,6 +5,7 @@ """ from keras.src.ops.linalg import cholesky as cholesky +from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse from keras.src.ops.linalg import det as det from keras.src.ops.linalg import eig as eig from keras.src.ops.linalg import eigh as eigh diff --git a/keras/src/backend/jax/linalg.py b/keras/src/backend/jax/linalg.py index 05a623d89101..f36b80dcb09d 100644 --- a/keras/src/backend/jax/linalg.py +++ b/keras/src/backend/jax/linalg.py @@ -9,8 +9,8 @@ from keras.src.backend.jax.core import convert_to_tensor -def cholesky(a): - out = jnp.linalg.cholesky(a) +def cholesky(a, upper=False): + out = jnp.linalg.cholesky(a, upper=upper) try: # In eager mode, raise for nan to # achieve behavior consistency with numpy @@ -26,6 +26,16 @@ def cholesky(a): return out +def cholesky_inverse(a, upper=False): + identity = jnp.eye(a.shape[-1], dtype=a.dtype) + inv_chol = solve_triangular(a, identity, lower=not upper) + if upper: + a_inv = jnp.matmul(inv_chol, jnp.transpose(inv_chol)) + else: + a_inv = jnp.matmul(jnp.transpose(inv_chol), inv_chol) + return a_inv + + def det(a): return jnp.linalg.det(a) diff --git a/keras/src/backend/numpy/linalg.py b/keras/src/backend/numpy/linalg.py index 30881964f7c5..9fe27f6aac11 100644 --- a/keras/src/backend/numpy/linalg.py +++ b/keras/src/backend/numpy/linalg.py @@ -6,8 +6,18 @@ from keras.src.backend.numpy.core import convert_to_tensor -def cholesky(a): - return np.linalg.cholesky(a) +def cholesky(a, upper=False): + return np.linalg.cholesky(a, upper=upper) + + +def cholesky_inverse(a, upper=False): + identity = np.eye(a.shape[-1], dtype=a.dtype) + inv_chol = solve_triangular(a, identity, lower=not upper) + if upper: + a_inv = np.matmul(inv_chol, inv_chol.T) + else: + a_inv = np.matmul(inv_chol.T, inv_chol) + return a_inv def det(a): diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 7968a7931b89..01f413d65aee 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -289,3 +289,5 @@ TestMathErrors::test_istft_invalid_window_shape_2D_inputs TestMathErrors::test_stft_invalid_input_type TestMathErrors::test_stft_invalid_window TestMathErrors::test_stft_invalid_window_shape +LinalgOpsCorrectnessTest::test_cholesky +LinalgOpsCorrectnessTest::test_cholesky_inverse diff --git a/keras/src/backend/openvino/linalg.py b/keras/src/backend/openvino/linalg.py index 3703bd83a0c1..948c4d480144 100644 --- a/keras/src/backend/openvino/linalg.py +++ b/keras/src/backend/openvino/linalg.py @@ -1,6 +1,12 @@ -def cholesky(a): +def cholesky(a, upper=False): raise NotImplementedError( - "`cholesky` is not supported with openvino backend" + "`cholesky` is not supported with openvino backend." + ) + + +def cholesky_inverse(a, upper=False): + raise NotImplementedError( + "`cholesky_inverse` is not supported with openvino backend." ) diff --git a/keras/src/backend/tensorflow/linalg.py b/keras/src/backend/tensorflow/linalg.py index da1ff4259685..9a1f1b615249 100644 --- a/keras/src/backend/tensorflow/linalg.py +++ b/keras/src/backend/tensorflow/linalg.py @@ -7,10 +7,23 @@ from keras.src.backend.tensorflow.core import convert_to_tensor -def cholesky(a): +def cholesky(a, upper=False): out = tf.linalg.cholesky(a) # tf.linalg.cholesky simply returns NaNs for non-positive definite matrices - return tf.debugging.check_numerics(out, "Cholesky") + out = tf.debugging.check_numerics(out, "Cholesky") + if upper: + return tf.linalg.adjoint(out) + return out + + +def cholesky_inverse(a, upper=False): + identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype) + inv_chol = tf.linalg.triangular_solve(a, identity, lower=not upper) + if upper: + a_inv = tf.matmul(inv_chol, inv_chol, transpose_b=True) + else: + a_inv = tf.matmul(inv_chol, inv_chol, transpose_a=True) + return a_inv def det(a): diff --git a/keras/src/backend/torch/linalg.py b/keras/src/backend/torch/linalg.py index 939074a680cd..bae9733e36a4 100644 --- a/keras/src/backend/torch/linalg.py +++ b/keras/src/backend/torch/linalg.py @@ -7,8 +7,12 @@ from keras.src.backend.torch.core import convert_to_tensor -def cholesky(x): - return torch.linalg.cholesky(x) +def cholesky(x, upper=False): + return torch.linalg.cholesky(x, upper=upper) + + +def cholesky_inverse(x, upper=False): + return torch.cholesky_inverse(x, upper=upper) def det(x): diff --git a/keras/src/ops/linalg.py b/keras/src/ops/linalg.py index dc8004f309fd..c294454dd5b2 100644 --- a/keras/src/ops/linalg.py +++ b/keras/src/ops/linalg.py @@ -7,8 +7,12 @@ class Cholesky(Operation): + def __init__(self, upper=False, *, name=None): + super().__init__(name=name) + self.upper = upper + def call(self, x): - return _cholesky(x) + return _cholesky(x, self.upper) def compute_output_spec(self, x): _assert_2d(x) @@ -17,32 +21,78 @@ def compute_output_spec(self, x): @keras_export(["keras.ops.cholesky", "keras.ops.linalg.cholesky"]) -def cholesky(x): +def cholesky(x, upper=False): """Computes the Cholesky decomposition of a positive semi-definite matrix. Args: x: Input tensor of shape `(..., M, M)`. + upper (bool): If True, returns the upper-triangular Cholesky factor. + If False (default), returns the lower-triangular Cholesky factor. Returns: - A tensor of shape `(..., M, M)` representing the lower triangular - Cholesky factor of `x`. - + A tensor of shape `(..., M, M)` representing the Cholesky factor of `x`. """ if any_symbolic_tensors((x,)): - return Cholesky().symbolic_call(x) - return _cholesky(x) + return Cholesky(upper=upper).symbolic_call(x) + return _cholesky(x, upper=upper) -def _cholesky(x): +def _cholesky(x, upper=False): x = backend.convert_to_tensor(x) _assert_2d(x) _assert_square(x) try: - return backend.linalg.cholesky(x) + return backend.linalg.cholesky(x, upper=upper) except Exception as e: raise ValueError(f"Cholesky decomposition failed: {e}") +class CholeskyInverse(Operation): + def __init__(self, upper=False, *, name=None): + super().__init__(name=name) + self.upper = upper + + def call(self, x): + return _cholesky_inverse(x, self.upper) + + def compute_output_spec(self, x): + _assert_2d(x) + _assert_square(x) + return KerasTensor(x.shape, x.dtype) + + +@keras_export( + ["keras.ops.cholesky_inverse", "keras.ops.linalg.cholesky_inverse"] +) +def cholesky_inverse(x, upper=False): + """Computes the inverse of a symmetric positive-definite matrix. + + Args: + x: Input tensor of shape `(..., M, M)`. + upper (bool): Determines whether to use the upper- or lower-triangular + factor for the internal computation. Defaults to False. + + Returns: + A tensor of shape `(..., M, M)` representing the inverse of `x`. + + Raises: + ValueError: If `x` is not a symmetric positive-definite matrix. + """ + if any_symbolic_tensors((x,)): + return CholeskyInverse(upper=upper).symbolic_call(x) + return _cholesky_inverse(x, upper=upper) + + +def _cholesky_inverse(x, upper=False): + x = backend.convert_to_tensor(x) + _assert_2d(x) + _assert_square(x) + try: + return backend.linalg.cholesky_inverse(x, upper=upper) + except Exception as e: + raise ValueError(f"Cholesky inverse failed: {e}") + + class Det(Operation): def call(self, x): return _det(x) diff --git a/keras/src/ops/linalg_test.py b/keras/src/ops/linalg_test.py index 67d72b32eee8..5ae00af915c8 100644 --- a/keras/src/ops/linalg_test.py +++ b/keras/src/ops/linalg_test.py @@ -23,6 +23,19 @@ def test_cholesky(self): with self.assertRaises(ValueError): linalg.cholesky(x) + def test_cholesky_inverse(self): + x = KerasTensor([None, 20, 20]) + out = linalg.cholesky_inverse(x) + self.assertEqual(out.shape, (None, 20, 20)) + + x = KerasTensor([None, None, 20]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + + x = KerasTensor([None, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + def test_det(self): x = KerasTensor([None, 20, 20]) out = linalg.det(x) @@ -196,6 +209,15 @@ def test_cholesky(self): with self.assertRaises(ValueError): linalg.cholesky(x) + def test_cholesky_inverse(self): + x = KerasTensor([4, 3, 3]) + out = linalg.cholesky_inverse(x) + self.assertEqual(out.shape, (4, 3, 3)) + + x = KerasTensor([10, 20, 15]) + with self.assertRaises(ValueError): + linalg.cholesky_inverse(x) + def test_det(self): x = KerasTensor([4, 3, 3]) out = linalg.det(x) @@ -331,12 +353,52 @@ def test_svd(self): class LinalgOpsCorrectnessTest(testing.TestCase): def test_cholesky(self): - x = np.random.rand(4, 3, 3).astype("float32") + x_non_psd = np.random.rand(4, 3, 3).astype("float32") with self.assertRaises(ValueError): - linalg.cholesky(x) - x_psd = x @ x.transpose((0, 2, 1)) + 1e-5 * np.eye(3) - out = linalg.cholesky(x_psd) - self.assertAllClose(out, np.linalg.cholesky(x_psd), atol=1e-4) + linalg.cholesky(x_non_psd) + + x = np.random.rand(4, 3, 3).astype("float32") + x_psd = np.matmul(x, x.transpose((0, 2, 1))) + 1e-5 * np.eye( + 3, dtype="float32" + ) + + l_out = linalg.cholesky(x_psd, upper=False) + l_expected = np.linalg.cholesky(x_psd) + self.assertAllClose(l_out, l_expected, atol=1e-4) + + u_out = linalg.cholesky(x_psd, upper=True) + u_expected = l_expected.transpose((0, 2, 1)) + self.assertAllClose(u_out, u_expected, atol=1e-4) + + @parameterized.named_parameters( + {"testcase_name": "lower", "upper": False}, + {"testcase_name": "upper", "upper": True}, + ) + def test_cholesky_inverse(self, upper): + A = np.array( + [ + [4.0, 12.0, -16.0], + [12.0, 37.0, -43.0], + [-16.0, -43.0, 98.0], + ], + dtype="float32", + ) + if upper: + factor = np.linalg.cholesky(A, upper=True) + else: + factor = np.linalg.cholesky(A) + + expected_inverse = np.array( + [ + [49.36111, -13.555555, 2.111111], + [-13.555555, 3.777778, -0.555556], + [2.111111, -0.555556, 0.111111], + ], + dtype="float32", + ) + + output_inverse = linalg.cholesky_inverse(factor, upper=upper) + self.assertAllClose(output_inverse, expected_inverse, atol=1e-5) def test_det(self): x = np.random.rand(4, 3, 3)