1111from __future__ import annotations
1212
1313import warnings
14- from typing import TYPE_CHECKING , Callable , List , Literal , Optional , Tuple , Union
14+ from typing import Callable , List , Literal , Optional , Tuple , Union
1515
1616import geoopt
1717import torch
@@ -33,13 +33,7 @@ class Manifold:
3333 stereographic: (bool) Whether to use stereographic coordinates for the manifold.
3434 """
3535
36- def __init__ (
37- self ,
38- curvature : float ,
39- dim : int ,
40- device : str = "cpu" ,
41- stereographic : bool = False ,
42- ):
36+ def __init__ (self , curvature : float , dim : int , device : str = "cpu" , stereographic : bool = False ):
4337 # Device management
4438 self .device = device
4539
@@ -103,9 +97,7 @@ def to(self, device: str) -> "Manifold":
10397 return self
10498
10599 def inner (
106- self ,
107- X : Float [torch .Tensor , "n_points1 n_dim" ],
108- Y : Float [torch .Tensor , "n_points2 n_dim" ],
100+ self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
109101 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
110102 """
111103 Compute the inner product of manifolds.
@@ -126,9 +118,7 @@ def inner(
126118 return X_fixed @ Y .T * scaler
127119
128120 def dist (
129- self ,
130- X : Float [torch .Tensor , "n_points1 n_dim" ],
131- Y : Float [torch .Tensor , "n_points2 n_dim" ],
121+ self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
132122 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
133123 """
134124 Inherit distance function from the geoopt manifold.
@@ -143,9 +133,7 @@ def dist(
143133 return self .manifold .dist (X [:, None ], Y [None , :])
144134
145135 def dist2 (
146- self ,
147- X : Float [torch .Tensor , "n_points1 n_dim" ],
148- Y : Float [torch .Tensor , "n_points2 n_dim" ],
136+ self , X : Float [torch .Tensor , "n_points1 n_dim" ], Y : Float [torch .Tensor , "n_points2 n_dim" ]
149137 ) -> Float [torch .Tensor , "n_points1 n_points2" ]:
150138 """
151139 Inherit squared distance function from the geoopt manifold.
@@ -265,7 +253,7 @@ def log_likelihood(
265253 z : Float [torch .Tensor , "n_points n_ambient_dim" ],
266254 mu : Optional [Float [torch .Tensor , "n_points n_ambient_dim" ]] = None ,
267255 sigma : Optional [Float [torch .Tensor , "n_points n_dim n_dim" ]] = None ,
268- ) -> Float [torch .Tensor , "n_points" ]:
256+ ) -> Float [torch .Tensor , "n_points, " ]:
269257 """
270258 Probability density function for WN(z ; mu, Sigma) in manifold
271259
@@ -321,9 +309,7 @@ def log_likelihood(
321309 return ll - (n - 1 ) * torch .log (R * torch .abs (sin_M (u_norm / R ) / u_norm ) + 1e-8 )
322310
323311 def logmap (
324- self ,
325- x : Float [torch .Tensor , "n_points n_dim" ],
326- base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None ,
312+ self , x : Float [torch .Tensor , "n_points n_dim" ], base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None
327313 ) -> Float [torch .Tensor , "n_points n_dim" ]:
328314 """
329315 Logarithmic map of point on manifold x at base point.
@@ -341,9 +327,7 @@ def logmap(
341327 return self .manifold .logmap (x = base , y = x )
342328
343329 def expmap (
344- self ,
345- u : Float [torch .Tensor , "n_points n_dim" ],
346- base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None ,
330+ self , u : Float [torch .Tensor , "n_points n_dim" ], base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None
347331 ) -> Float [torch .Tensor , "n_points n_dim" ]:
348332 """
349333 Exponential map of tangent vector u at base point.
@@ -415,7 +399,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
415399 return orig_manifold , * points # type: ignore
416400
417401 # Inverse projection for points
418- norm_squared = [(Y ** 2 ).sum (dim = 1 , keepdim = True ) for Y in points ]
402+ norm_squared = [(Y ** 2 ).sum (dim = 1 , keepdim = True ) for Y in points ]
419403 sign = torch .sign (self .curvature ) # type: ignore
420404
421405 X0 = (1 + sign * norm_squared ) / (1 - sign * norm_squared )
@@ -457,12 +441,7 @@ class ProductManifold(Manifold):
457441 stereographic: (bool) Whether to use stereographic coordinates for the manifold.
458442 """
459443
460- def __init__ (
461- self ,
462- signature : List [Tuple [float , int ]],
463- device : str = "cpu" ,
464- stereographic : bool = False ,
465- ):
444+ def __init__ (self , signature : List [Tuple [float , int ]], device : str = "cpu" , stereographic : bool = False ):
466445 # Device management
467446 self .device = device
468447
@@ -485,12 +464,7 @@ def __init__(
485464
486465 # Manifold <-> Dimension mapping
487466 self .ambient_dim , self .n_manifolds , self .dim = 0 , 0 , 0
488- self .dim2man , self .man2dim , self .man2intrinsic , self .intrinsic2man = (
489- {},
490- {},
491- {},
492- {},
493- )
467+ self .dim2man , self .man2dim , self .man2intrinsic , self .intrinsic2man = {}, {}, {}, {}
494468
495469 for M in self .P :
496470 for d in range (self .ambient_dim , self .ambient_dim + M .ambient_dim ):
@@ -549,10 +523,7 @@ def sample(
549523 sigma_factorized : Optional [List [Float [torch .Tensor , "n_points n_dim_manifold n_dim_manifold" ]]] = None ,
550524 ) -> Union [
551525 Float [torch .Tensor , "n_points n_ambient_dim" ],
552- Tuple [
553- Float [torch .Tensor , "n_points n_ambient_dim" ],
554- Float [torch .Tensor , "n_points n_dim" ],
555- ],
526+ Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]],
556527 ]:
557528 """
558529 Sample from the variational distribution.
@@ -593,9 +564,9 @@ def sample(
593564 def log_likelihood (
594565 self ,
595566 z : Float [torch .Tensor , "batch_size n_dim" ],
596- mu : Optional [Float [torch .Tensor , "n_dim" ]] = None ,
567+ mu : Optional [Float [torch .Tensor , "n_dim, " ]] = None ,
597568 sigma_factorized : Optional [List [Float [torch .Tensor , "n_points n_dim_manifold n_dim_manifold" ]]] = None ,
598- ) -> Float [torch .Tensor , "batch_size" ]:
569+ ) -> Float [torch .Tensor , "batch_size, " ]:
599570 """
600571 Probability density function for WN(z ; mu, Sigma) in manifold
601572
@@ -624,7 +595,7 @@ def log_likelihood(
624595 ]
625596 return torch .cat (component_lls , axis = 1 ).sum (axis = 1 ) # type: ignore
626597
627- def stereographic (self , * points : Float [torch .Tensor , "n_points n_dim" ]) -> Tuple [Manifold , ...]:
598+ def stereographic (self , * points : Float [torch .Tensor , "n_points n_dim" ]) -> Tuple ["ProductManifold" , ...]:
628599 if self .is_stereographic :
629600 print ("Manifold is already in stereographic coordinates." )
630601 return self , * points # type: ignore
@@ -653,7 +624,7 @@ def gaussian_mixture(
653624 regression_noise_std : float = 0.1 ,
654625 task : Literal ["classification" , "regression" ] = "classification" ,
655626 adjust_for_dims : bool = False ,
656- ) -> Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points" ]]:
627+ ) -> Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points, " ]]:
657628 """
658629 Generate a set of labeled samples from a Gaussian mixture model.
659630
0 commit comments