@@ -196,10 +196,7 @@ def sample(
196196 sigma : Optional [Float [torch .Tensor , "n_points n_dim n_dim" ]] = None ,
197197 ) -> Union [
198198 Float [torch .Tensor , "n_points n_ambient_dim" ],
199- Tuple [
200- Float [torch .Tensor , "n_points n_ambient_dim" ],
201- Float [torch .Tensor , "n_points n_dim" ],
202- ],
199+ Tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]],
203200 ]:
204201 """
205202 Sample from the variational distribution.
@@ -228,7 +225,7 @@ def sample(
228225 N = torch .distributions .MultivariateNormal (
229226 loc = torch .zeros ((n , self .dim ), device = self .device ), covariance_matrix = sigma
230227 )
231- v = N .sample () # type: ignore
228+ v = N .sample ()
232229
233230 # Don't need to adjust normal vectors for the Scaled manifold class in geoopt - very cool!
234231
@@ -356,14 +353,14 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
356353
357354 if self .is_stereographic :
358355 print ("Manifold is already in stereographic coordinates." )
359- return self , * points # type: ignore
356+ return self , * points
360357
361358 # Convert manifold
362359 stereo_manifold = Manifold (self .curvature , self .dim , device = self .device , stereographic = True )
363360
364361 # Euclidean edge case
365362 if self .type == "E" :
366- return stereo_manifold , * points # type: ignore
363+ return stereo_manifold , * points
367364
368365 # Convert points
369366 num = [X [:, 1 :] for X in points ]
@@ -373,7 +370,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
373370 stereo_points = [n / d for n , d in zip (num , denom )]
374371 assert all ([stereo_manifold .manifold .check_point (X ) for X in stereo_points ])
375372
376- return stereo_manifold , * stereo_points # type: ignore
373+ return stereo_manifold , * stereo_points
377374
378375 def inverse_stereographic (self , * points : Float [torch .Tensor , "n_points n_dim_stereo" ]) -> Tuple ["Manifold" , ...]:
379376 """
@@ -389,14 +386,14 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
389386 """
390387 if not self .is_stereographic :
391388 print ("Manifold is already in original coordinates." )
392- return self , * points # type: ignore
389+ return self , * points
393390
394391 # Convert manifold
395392 orig_manifold = Manifold (self .curvature , self .dim , device = self .device , stereographic = False )
396393
397394 # Euclidean edge case
398395 if self .type == "E" :
399- return orig_manifold , * points # type: ignore
396+ return orig_manifold , * points
400397
401398 # Inverse projection for points
402399 out = []
@@ -427,7 +424,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
427424
428425 out .append (inv_points )
429426
430- return orig_manifold , * out # type: ignore
427+ return orig_manifold , * out
431428
432429 def apply (self , f : Callable ) -> Callable :
433430 """
@@ -475,11 +472,11 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
475472 # Actually initialize the geoopt manifolds; other derived properties
476473 self .P = [Manifold (curvature , dim , device = device , stereographic = stereographic ) for curvature , dim in signature ]
477474 manifold_class = geoopt .StereographicProductManifold if stereographic else geoopt .ProductManifold
478- self .manifold = manifold_class (* [(M .manifold , M .ambient_dim ) for M in self .P ]).to (device ) # type: ignore
475+ self .manifold = manifold_class (* [(M .manifold , M .ambient_dim ) for M in self .P ]).to (device )
479476 self .name = " x " .join ([M .name for M in self .P ])
480477
481478 # Origin
482- self .mu0 = torch .cat ([M .mu0 for M in self .P ], axis = 1 ).to (self .device ) # type: ignore
479+ self .mu0 = torch .cat ([M .mu0 for M in self .P ], axis = 1 ).to (self .device )
483480
484481 # Manifold <-> Dimension mapping
485482 self .ambient_dim , self .n_manifolds , self .dim = 0 , 0 , 0
@@ -507,11 +504,11 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
507504 for j , k in zip (intrinsic_dims , ambient_dims [- len (intrinsic_dims ) :]):
508505 self .projection_matrix [j , k ] = 1.0
509506
510- def params (self ):
507+ def params (self ) -> List [ float ] :
511508 """Returns scales for all component manifolds"""
512509 return [x .scale () for x in self .manifold .manifolds ]
513510
514- def to (self , device : str ):
511+ def to (self , device : str ) -> "ProductManifold" :
515512 """Move all components to a new device"""
516513 self .device = device
517514 self .P = [M .to (device ) for M in self .P ]
@@ -628,24 +625,23 @@ def log_likelihood(
628625 M .log_likelihood (z_M , mu_M , sigma_M ).unsqueeze (dim = 1 )
629626 for M , z_M , mu_M , sigma_M in zip (self .P , z_factorized , mu_factorized , sigma_factorized )
630627 ]
631- return torch .cat (component_lls , axis = 1 ).sum (axis = 1 ) # type: ignore
628+ return torch .cat (component_lls , axis = 1 ).sum (axis = 1 )
632629
633630 def stereographic (self , * points : Float [torch .Tensor , "n_points n_dim" ]) -> Tuple ["ProductManifold" , ...]:
634631 if self .is_stereographic :
635632 print ("Manifold is already in stereographic coordinates." )
636- return self , * points # type: ignore
633+ return self , * points
637634
638635 # Convert manifold
639636 stereo_manifold = ProductManifold (self .signature , device = self .device , stereographic = True )
640637
641638 # Convert points
642639 stereo_points = [
643- torch .hstack ([M .stereographic (x )[1 ] for x , M in zip (self .factorize (X ), self .P )]) # type: ignore
644- for X in points
640+ torch .hstack ([M .stereographic (x )[1 ] for x , M in zip (self .factorize (X ), self .P )]) for X in points
645641 ]
646642 assert all ([stereo_manifold .manifold .check_point (X ) for X in stereo_points ])
647643
648- return stereo_manifold , * stereo_points # type: ignore
644+ return stereo_manifold , * stereo_points
649645
650646 def inverse_stereographic (self , * points : Float [torch .Tensor , "n_points n_dim_stereo" ]) -> Tuple [Manifold ]:
651647 if not self .is_stereographic :
@@ -660,9 +656,9 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
660656 ]
661657 assert all ([orig_manifold .manifold .check_point (X ) for X in orig_points ])
662658
663- return orig_manifold , * orig_points # type: ignore
659+ return orig_manifold , * orig_points
664660
665- @torch .no_grad ()
661+ @torch .no_grad () # type: ignore
666662 def gaussian_mixture (
667663 self ,
668664 num_points : int = 1_000 ,
@@ -713,7 +709,7 @@ def gaussian_mixture(
713709 z_mean = torch .stack ([self .mu0 ] * num_clusters ),
714710 sigma_factorized = [torch .stack ([torch .eye (M .dim )] * num_clusters ) * cov_scale_means for M in self .P ],
715711 )
716- assert cluster_means .shape == (num_clusters , self .ambient_dim ) # type: ignore
712+ assert cluster_means .shape == (num_clusters , self .ambient_dim )
717713
718714 # Generate class assignments
719715 cluster_probs = torch .rand (num_clusters )
@@ -726,10 +722,8 @@ def gaussian_mixture(
726722
727723 # Generate covariance matrices for each class - Wishart distribution
728724 cov_matrices = [
729- torch .distributions .Wishart (
730- df = M .dim + 1 , covariance_matrix = torch .eye (M .dim ) * cov_scale_points # type: ignore
731- ).sample (
732- sample_shape = (num_clusters ,) # type: ignore
725+ torch .distributions .Wishart (df = M .dim + 1 , covariance_matrix = torch .eye (M .dim ) * cov_scale_points ).sample (
726+ sample_shape = (num_clusters ,)
733727 )
734728 + torch .eye (M .dim ) * 1e-5 # jitter to avoid singularity
735729 for M in self .P
@@ -764,7 +758,7 @@ def gaussian_mixture(
764758
765759 # Noise component
766760 N = torch .distributions .Normal (0 , regression_noise_std )
767- v = N .sample ((num_points ,)).to (self .device ) # type: ignore
761+ v = N .sample ((num_points ,)).to (self .device )
768762 labels += v
769763
770764 # Normalize regression labels to range [0, 1] so that RMSE can be more easily interpreted
0 commit comments