@@ -55,7 +55,7 @@ def _mobius_add(self, x, y):
55
55
x_y = tf .reduce_sum (x * y , axis = - 1 , keepdims = True )
56
56
k = tf .cast (self .k , x .dtype )
57
57
return ((1 + 2 * k * x_y + k * y_2 ) * x + (1 - k * x_2 ) * y ) / (
58
- 1 + 2 * k * x_y + k ** 2 * x_2 * y_2
58
+ 1 + 2 * k * x_y + k ** 2 * x_2 * y_2
59
59
)
60
60
61
61
def _mobius_scal_mul (self , x , r ):
@@ -91,15 +91,15 @@ def _lambda(self, x, keepdims=False):
91
91
92
92
def inner (self , x , u , v , keepdims = False ):
93
93
lambda_x = self ._lambda (x , keepdims = keepdims )
94
- return tf .reduce_sum (u * v , axis = - 1 , keepdims = keepdims ) * lambda_x ** 2
94
+ return tf .reduce_sum (u * v , axis = - 1 , keepdims = keepdims ) * lambda_x ** 2
95
95
96
96
def norm (self , x , u , keepdims = False ):
97
97
lambda_x = self ._lambda (x , keepdims = keepdims )
98
98
return tf .linalg .norm (u , axis = - 1 , keepdims = keepdims ) * lambda_x
99
99
100
100
def proju (self , x , u ):
101
101
lambda_x = self ._lambda (x , keepdims = True )
102
- return u / lambda_x ** 2
102
+ return u / lambda_x ** 2
103
103
104
104
def projx (self , x ):
105
105
sqrt_k = tf .math .sqrt (tf .cast (self .k , x .dtype ))
0 commit comments