Skip to content

Commit 6d593b5

Browse files
committed
Fix code style
1 parent 4134045 commit 6d593b5

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tensorflow_riemopt/manifolds/cholesky.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def proju(self, x, u):
4747
u_sym = (utils.transposem(u) + u) / 2.0
4848
u_diag, u_lower = self._diag_and_strictly_lower(u_sym)
4949
x_diag = tf.linalg.diag_part(x)
50-
return u_lower + tf.linalg.diag(u_diag * x_diag ** 2)
50+
return u_lower + tf.linalg.diag(u_diag * x_diag**2)
5151

5252
def inner(self, x, u, v, keepdims=False):
5353
u_diag, u_lower = self._diag_and_strictly_lower(u)
@@ -92,7 +92,7 @@ def dist(self, x, y, keepdims=False):
9292
axis=-1,
9393
keepdims=keepdims,
9494
)
95-
return tf.math.sqrt(lower ** 2 + tf.reshape(diag, lower.shape) ** 2)
95+
return tf.math.sqrt(lower**2 + tf.reshape(diag, lower.shape) ** 2)
9696

9797
def ptransp(self, x, y, v):
9898
x_diag, _ = self._diag_and_strictly_lower(x)

tensorflow_riemopt/manifolds/poincare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _mobius_add(self, x, y):
5555
x_y = tf.reduce_sum(x * y, axis=-1, keepdims=True)
5656
k = tf.cast(self.k, x.dtype)
5757
return ((1 + 2 * k * x_y + k * y_2) * x + (1 - k * x_2) * y) / (
58-
1 + 2 * k * x_y + k ** 2 * x_2 * y_2
58+
1 + 2 * k * x_y + k**2 * x_2 * y_2
5959
)
6060

6161
def _mobius_scal_mul(self, x, r):
@@ -91,15 +91,15 @@ def _lambda(self, x, keepdims=False):
9191

9292
def inner(self, x, u, v, keepdims=False):
9393
lambda_x = self._lambda(x, keepdims=keepdims)
94-
return tf.reduce_sum(u * v, axis=-1, keepdims=keepdims) * lambda_x ** 2
94+
return tf.reduce_sum(u * v, axis=-1, keepdims=keepdims) * lambda_x**2
9595

9696
def norm(self, x, u, keepdims=False):
9797
lambda_x = self._lambda(x, keepdims=keepdims)
9898
return tf.linalg.norm(u, axis=-1, keepdims=keepdims) * lambda_x
9999

100100
def proju(self, x, u):
101101
lambda_x = self._lambda(x, keepdims=True)
102-
return u / lambda_x ** 2
102+
return u / lambda_x**2
103103

104104
def projx(self, x):
105105
sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))

0 commit comments

Comments
 (0)