@@ -185,8 +185,7 @@ def _to_tangent_plane_mu0(
185185 x = torch .Tensor (x ).reshape (- 1 , self .dim )
186186 if self .type == "E" :
187187 return x
188- else :
189- return torch .cat ([torch .zeros ((x .shape [0 ], 1 ), device = self .device ), x ], dim = 1 )
188+ return torch .cat ([torch .zeros ((x .shape [0 ], 1 ), device = self .device ), x ], dim = 1 )
190189
191190 def sample (
192191 self ,
@@ -277,31 +276,30 @@ def log_likelihood(
277276 if self .type == "E" :
278277 return torch .distributions .MultivariateNormal (mu , sigma ).log_prob (z )
279278
280- else :
281- u = self .manifold .logmap (x = mu , y = z ) # Map z to tangent space at mu
282- v = self .manifold .transp (x = mu , y = self .mu0 , v = u ) # Parallel transport to origin
283- # assert torch.allclose(v[:, 0], torch.Tensor([0.])) # For tangent vectors at origin this should be true
284- # OK, so this assertion doesn't actually pass, but it's spiritually true
285- if torch .isnan (v ).any ():
286- print ("NANs in parallel transport" )
287- v = torch .nan_to_num (v , nan = 0.0 )
288- N = torch .distributions .MultivariateNormal (torch .zeros (self .dim , device = self .device ), sigma )
289- ll = N .log_prob (v [:, 1 :])
290-
291- # For convenience
292- R = self .scale
293- n = self .dim
294-
295- # Final formula (epsilon to avoid log(0))
296- if self .type == "S" :
297- sin_M = torch .sin
298- u_norm = self .manifold .norm (x = mu , u = u )
279+ u = self .manifold .logmap (x = mu , y = z ) # Map z to tangent space at mu
280+ v = self .manifold .transp (x = mu , y = self .mu0 , v = u ) # Parallel transport to origin
281+ # assert torch.allclose(v[:, 0], torch.Tensor([0.])) # For tangent vectors at origin this should be true
282+ # OK, so this assertion doesn't actually pass, but it's spiritually true
283+ if torch .isnan (v ).any ():
284+ print ("NANs in parallel transport" )
285+ v = torch .nan_to_num (v , nan = 0.0 )
286+ N = torch .distributions .MultivariateNormal (torch .zeros (self .dim , device = self .device ), sigma )
287+ ll = N .log_prob (v [:, 1 :])
288+
289+ # For convenience
290+ R = self .scale
291+ n = self .dim
292+
293+ # Final formula (epsilon to avoid log(0))
294+ if self .type == "S" :
295+ sin_M = torch .sin
296+ u_norm = self .manifold .norm (x = mu , u = u )
299297
300- else :
301- sin_M = torch .sinh
302- u_norm = self .manifold .base .norm (u = u ) # Horrible workaround needed for geoopt bug # type: ignore
298+ else :
299+ sin_M = torch .sinh
300+ u_norm = self .manifold .base .norm (u = u ) # Horrible workaround needed for geoopt bug # type: ignore
303301
304- return ll - (n - 1 ) * torch .log (R * torch .abs (sin_M (u_norm / R ) / u_norm ) + 1e-8 )
302+ return ll - (n - 1 ) * torch .log (R * torch .abs (sin_M (u_norm / R ) / u_norm ) + 1e-8 )
305303
306304 def logmap (
307305 self , x : Float [torch .Tensor , "n_points n_dim" ], base : Optional [Float [torch .Tensor , "n_points n_dim" ]] = None
@@ -366,7 +364,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
366364 for X in denom :
367365 X [X .abs () < 1e-6 ] = 1e-6 # Avoid division by zero
368366 stereo_points = [n / d for n , d in zip (num , denom )]
369- assert all ([ stereo_manifold .manifold .check_point (X ) for X in stereo_points ] )
367+ assert all (stereo_manifold .manifold .check_point (X ) for X in stereo_points )
370368
371369 return stereo_manifold , * stereo_points
372370
@@ -579,7 +577,7 @@ def sample(
579577 for M , sigma in zip (self .P , sigma_factorized )
580578 ]
581579
582- assert sum ([ sigma .shape == (n , M .dim , M .dim ) for M , sigma in zip (self .P , sigma_factorized )]) == len ( self . P )
580+ assert all ( sigma .shape == (n , M .dim , M .dim ) for M , sigma in zip (self .P , sigma_factorized ))
583581 assert z_mean .shape [- 1 ] == self .ambient_dim
584582
585583 # Sample initial vector from N(0, sigma)
@@ -637,7 +635,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
637635 stereo_points = [
638636 torch .hstack ([M .stereographic (x )[1 ] for x , M in zip (self .factorize (X ), self .P )]) for X in points
639637 ]
640- assert all ([ stereo_manifold .manifold .check_point (X ) for X in stereo_points ] )
638+ assert all (stereo_manifold .manifold .check_point (X ) for X in stereo_points )
641639
642640 return stereo_manifold , * stereo_points
643641
@@ -652,7 +650,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
652650 orig_points = [
653651 torch .hstack ([M .inverse_stereographic (x )[1 ] for x , M in zip (self .factorize (X ), self .P )]) for X in points
654652 ]
655- assert all ([ orig_manifold .manifold .check_point (X ) for X in orig_points ] )
653+ assert all (orig_manifold .manifold .check_point (X ) for X in orig_points )
656654
657655 return orig_manifold , * orig_points
658656
0 commit comments