88includes functions for different key geometric operations.
99"""
1010
11+ from __future__ import annotations
12+
1113import warnings
12- from typing import Callable , List , Literal , Optional , Tuple , Union
14+ from typing import TYPE_CHECKING , Callable , List , Literal , Optional , Tuple , Union
1315
1416import geoopt
1517import torch
@@ -31,7 +33,13 @@ class Manifold:
3133 stereographic: (bool) Whether to use stereographic coordinates for the manifold.
3234 """
3335
34- def __init__ (self , curvature : float , dim : int , device : str = "cpu" , stereographic : bool = False ):
36+ def __init__ (
37+ self ,
38+ curvature : float ,
39+ dim : int ,
40+ device : str = "cpu" ,
41+ stereographic : bool = False ,
42+ ):
3543 # Device management
3644 self .device = device
3745
@@ -95,7 +103,9 @@ def to(self, device: str) -> "Manifold":
95103 return self
96104
97105 def inner (
98- self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
106+ self ,
107+ X : Float [torch .Tensor , "n_points1 n_dim" ],
108+ Y : Float [torch .Tensor , "n_points2 n_dim" ],
99109 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
100110 """
101111 Compute the inner product of manifolds.
@@ -116,7 +126,9 @@ def inner(
116126 return X_fixed @ Y .T * scaler
117127
118128 def dist (
119- self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
129+ self ,
130+ X : Float [torch .Tensor , "n_points1 n_dim" ],
131+ Y : Float [torch .Tensor , "n_points2 n_dim" ],
120132 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
121133 """
122134 Inherit distance function from the geoopt manifold.
@@ -131,7 +143,9 @@ def dist(
131143 return self .manifold .dist (X [:, None ], Y [None , :])
132144
133145 def dist2 (
134- self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
146+ self ,
147+ X : Float [torch .Tensor , "n_points1 n_dim" ],
148+ Y : Float [torch .Tensor , "n_points2 n_dim" ],
135149 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
136150 """
137151 Inherit squared distance function from the geoopt manifold.
@@ -194,7 +208,10 @@ def sample(
194208 sigma : Optional [Float [torch .Tensor , "n_points n_dim n_dim" ]] = None ,
195209 ) -> Union [
196210 Float [torch .Tensor , "n_points n_ambient_dim" ],
197- Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]],
211+ Tuple [
212+ Float [torch .Tensor , "n_points n_ambient_dim" ],
213+ Float [torch .Tensor , "n_points n_dim" ],
214+ ],
198215 ]:
199216 """
200217 Sample from the variational distribution.
@@ -304,7 +321,9 @@ def log_likelihood(
304321 return ll - (n - 1 ) * torch .log (R * torch .abs (sin_M (u_norm / R ) / u_norm ) + 1e-8 )
305322
306323 def logmap (
307- self , x : Float [torch .Tensor , "n_points n_dim" ], base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None
324+ self ,
325+ x : Float [torch .Tensor , "n_points n_dim" ],
326+ base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None ,
308327 ) -> Float [torch .Tensor , "n_points n_dim" ]:
309328 """
310329 Logarithmic map of point on manifold x at base point.
@@ -322,7 +341,9 @@ def logmap(
322341 return self .manifold .logmap (x = base , y = x )
323342
324343 def expmap (
325- self , u : Float [torch .Tensor , "n_points n_dim" ], base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None
344+ self ,
345+ u : Float [torch .Tensor , "n_points n_dim" ],
346+ base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None ,
326347 ) -> Float [torch .Tensor , "n_points n_dim" ]:
327348 """
328349 Exponential map of tangent vector u at base point.
@@ -394,7 +415,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
394415 return orig_manifold , * points # type: ignore
395416
396417 # Inverse projection for points
397- norm_squared = [(Y ** 2 ).sum (dim = 1 , keepdim = True ) for Y in points ]
418+ norm_squared = [(Y ** 2 ).sum (dim = 1 , keepdim = True ) for Y in points ]
398419 sign = torch .sign (self .curvature ) # type: ignore
399420
400421 X0 = (1 + sign * norm_squared ) / (1 - sign * norm_squared )
@@ -436,7 +457,12 @@ class ProductManifold(Manifold):
436457 stereographic: (bool) Whether to use stereographic coordinates for the manifold.
437458 """
438459
439- def __init__ (self , signature : List [Tuple [float , int ]], device : str = "cpu" , stereographic : bool = False ):
460+ def __init__ (
461+ self ,
462+ signature : List [Tuple [float , int ]],
463+ device : str = "cpu" ,
464+ stereographic : bool = False ,
465+ ):
440466 # Device management
441467 self .device = device
442468
@@ -459,7 +485,12 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
459485
460486 # Manifold <-> Dimension mapping
461487 self .ambient_dim , self .n_manifolds , self .dim = 0 , 0 , 0
462- self .dim2man , self .man2dim , self .man2intrinsic , self .intrinsic2man = {}, {}, {}, {}
488+ self .dim2man , self .man2dim , self .man2intrinsic , self .intrinsic2man = (
489+ {},
490+ {},
491+ {},
492+ {},
493+ )
463494
464495 for M in self .P :
465496 for d in range (self .ambient_dim , self .ambient_dim + M .ambient_dim ):
@@ -518,7 +549,10 @@ def sample(
518549 sigma_factorized : Optional [List [Float [torch .Tensor , "n_points n_dim_manifold n_dim_manifold" ]]] = None ,
519550 ) -> Union [
520551 Float [torch .Tensor , "n_points n_ambient_dim" ],
521- Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]],
552+ Tuple [
553+ Float [torch .Tensor , "n_points n_ambient_dim" ],
554+ Float [torch .Tensor , "n_points n_dim" ],
555+ ],
522556 ]:
523557 """
524558 Sample from the variational distribution.
0 commit comments