Skip to content

Commit cd7ec31

Browse files
Added cholesky inverse operation to all the backends (#21554)
* Added cholesky_inverse to all backends * Update keras/src/backend/openvino/linalg.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update keras/src/ops/linalg.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Added more tests for cholesky inverse op * Addressed the comments * Formatted the code for extra spaces and modified test for numerical verification * Added upper bool to the cholesky and cholesky_inverse operator * Modified test case for numerical verification * nit changes in the openvino linalg file * Modifying tests * Addressing the comments * Modified test case for better tests * Changed the test case * modified the test case * Added more tests * Added more tests * correcting the test case * Modifying cholesky_inverse functions * Fixed nit comments --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 8c55abe commit cd7ec31

File tree

12 files changed

+185
-24
lines changed

12 files changed

+185
-24
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from keras.src.ops.core import while_loop as while_loop
3333
from keras.src.ops.einops import rearrange as rearrange
3434
from keras.src.ops.linalg import cholesky as cholesky
35+
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
3536
from keras.src.ops.linalg import det as det
3637
from keras.src.ops.linalg import eig as eig
3738
from keras.src.ops.linalg import eigh as eigh

keras/api/_tf_keras/keras/ops/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from keras.src.ops.linalg import cholesky as cholesky
8+
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
89
from keras.src.ops.linalg import det as det
910
from keras.src.ops.linalg import eig as eig
1011
from keras.src.ops.linalg import eigh as eigh

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from keras.src.ops.core import while_loop as while_loop
3333
from keras.src.ops.einops import rearrange as rearrange
3434
from keras.src.ops.linalg import cholesky as cholesky
35+
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
3536
from keras.src.ops.linalg import det as det
3637
from keras.src.ops.linalg import eig as eig
3738
from keras.src.ops.linalg import eigh as eigh

keras/api/ops/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from keras.src.ops.linalg import cholesky as cholesky
8+
from keras.src.ops.linalg import cholesky_inverse as cholesky_inverse
89
from keras.src.ops.linalg import det as det
910
from keras.src.ops.linalg import eig as eig
1011
from keras.src.ops.linalg import eigh as eigh

keras/src/backend/jax/linalg.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from keras.src.backend.jax.core import convert_to_tensor
1010

1111

12-
def cholesky(a):
13-
out = jnp.linalg.cholesky(a)
12+
def cholesky(a, upper=False):
13+
out = jnp.linalg.cholesky(a, upper=upper)
1414
try:
1515
# In eager mode, raise for nan to
1616
# achieve behavior consistency with numpy
@@ -26,6 +26,16 @@ def cholesky(a):
2626
return out
2727

2828

29+
def cholesky_inverse(a, upper=False):
30+
identity = jnp.eye(a.shape[-1], dtype=a.dtype)
31+
inv_chol = solve_triangular(a, identity, lower=not upper)
32+
if upper:
33+
a_inv = jnp.matmul(inv_chol, jnp.transpose(inv_chol))
34+
else:
35+
a_inv = jnp.matmul(jnp.transpose(inv_chol), inv_chol)
36+
return a_inv
37+
38+
2939
def det(a):
3040
return jnp.linalg.det(a)
3141

keras/src/backend/numpy/linalg.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@
66
from keras.src.backend.numpy.core import convert_to_tensor
77

88

9-
def cholesky(a):
10-
return np.linalg.cholesky(a)
9+
def cholesky(a, upper=False):
10+
return np.linalg.cholesky(a, upper=upper)
11+
12+
13+
def cholesky_inverse(a, upper=False):
14+
identity = np.eye(a.shape[-1], dtype=a.dtype)
15+
inv_chol = solve_triangular(a, identity, lower=not upper)
16+
if upper:
17+
a_inv = np.matmul(inv_chol, inv_chol.T)
18+
else:
19+
a_inv = np.matmul(inv_chol.T, inv_chol)
20+
return a_inv
1121

1222

1323
def det(a):

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,5 @@ TestMathErrors::test_istft_invalid_window_shape_2D_inputs
285285
TestMathErrors::test_stft_invalid_input_type
286286
TestMathErrors::test_stft_invalid_window
287287
TestMathErrors::test_stft_invalid_window_shape
288+
LinalgOpsCorrectnessTest::test_cholesky
289+
LinalgOpsCorrectnessTest::test_cholesky_inverse

keras/src/backend/openvino/linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
def cholesky(a):
1+
def cholesky(a, upper=False):
22
raise NotImplementedError(
3-
"`cholesky` is not supported with openvino backend"
3+
"`cholesky` is not supported with openvino backend."
4+
)
5+
6+
7+
def cholesky_inverse(a, upper=False):
8+
raise NotImplementedError(
9+
"`cholesky_inverse` is not supported with openvino backend."
410
)
511

612

keras/src/backend/tensorflow/linalg.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,23 @@
77
from keras.src.backend.tensorflow.core import convert_to_tensor
88

99

10-
def cholesky(a):
10+
def cholesky(a, upper=False):
1111
out = tf.linalg.cholesky(a)
1212
# tf.linalg.cholesky simply returns NaNs for non-positive definite matrices
13-
return tf.debugging.check_numerics(out, "Cholesky")
13+
out = tf.debugging.check_numerics(out, "Cholesky")
14+
if upper:
15+
return tf.linalg.adjoint(out)
16+
return out
17+
18+
19+
def cholesky_inverse(a, upper=False):
20+
identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)
21+
inv_chol = tf.linalg.triangular_solve(a, identity, lower=not upper)
22+
if upper:
23+
a_inv = tf.matmul(inv_chol, inv_chol, transpose_b=True)
24+
else:
25+
a_inv = tf.matmul(inv_chol, inv_chol, transpose_a=True)
26+
return a_inv
1427

1528

1629
def det(a):

keras/src/backend/torch/linalg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
from keras.src.backend.torch.core import convert_to_tensor
88

99

10-
def cholesky(x):
11-
return torch.linalg.cholesky(x)
10+
def cholesky(x, upper=False):
11+
return torch.linalg.cholesky(x, upper=upper)
12+
13+
14+
def cholesky_inverse(x, upper=False):
15+
return torch.cholesky_inverse(x, upper=upper)
1216

1317

1418
def det(x):

0 commit comments

Comments
 (0)