Skip to content

Commit 0ee0ba3

Browse files
committed
Fix tangential projection on the Cholesky manifold
1 parent 6ab8306 commit 0ee0ba3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow_riemopt/manifolds/cholesky.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def projx(self, x):
4646
def proju(self, x, u):
4747
u_sym = (utils.transposem(u) + u) / 2.0
4848
u_diag, u_lower = self._diag_and_strictly_lower(u_sym)
49-
return u_lower + tf.linalg.diag(u_diag)
49+
x_diag = tf.linalg.diag_part(x)
50+
return u_lower + tf.linalg.diag(u_diag * x_diag ** 2)
5051

5152
def inner(self, x, u, v, keepdims=False):
5253
u_diag, u_lower = self._diag_and_strictly_lower(u)

0 commit comments

Comments
 (0)