@@ -224,14 +224,11 @@ def sample(
224224 x: Tensor of sampled points on the manifold
225225 v: Tensor of tangent vectors
226226 """
227- if z_mean is None :
228- z_mean = self .mu0
227+ z_mean = self .mu0 if z_mean is None else z_mean
229228 z_mean = torch .Tensor (z_mean ).reshape (- 1 , self .ambient_dim ).to (self .device )
230229 n = z_mean .shape [0 ]
231- if sigma is None :
232- sigma = torch .stack ([torch .eye (self .dim )] * n ).to (self .device )
233- else :
234- sigma = torch .Tensor (sigma ).reshape (- 1 , self .dim , self .dim ).to (self .device )
230+ sigma = torch .stack ([torch .eye (self .dim )] * n ).to (self .device ) if sigma is None else sigma
231+ sigma = torch .Tensor (sigma ).reshape (- 1 , self .dim , self .dim ).to (self .device )
235232 assert sigma .shape == (
236233 n ,
237234 self .dim ,
@@ -284,14 +281,11 @@ def log_likelihood(
284281 `mu` and covariance `sigma`.
285282 """
286283 # Default to mu=self.mu0 and sigma=I
287- if mu is None :
288- mu = self .mu0
284+ mu = self .mu0 if mu is None else mu
289285 mu = torch .Tensor (mu ).reshape (- 1 , self .ambient_dim ).to (self .device )
290286 n = mu .shape [0 ]
291- if sigma is None :
292- sigma = torch .stack ([torch .eye (self .dim )] * n ).to (self .device )
293- else :
294- sigma = torch .Tensor (sigma ).reshape (- 1 , self .dim , self .dim ).to (self .device )
287+ sigma = torch .stack ([torch .eye (self .dim )] * n ).to (self .device ) if sigma is None else sigma
288+ sigma = torch .Tensor (sigma ).reshape (- 1 , self .dim , self .dim ).to (self .device )
295289
296290 # Euclidean case is regular old Gaussian log-likelihood
297291 if self .type == "E" :
@@ -336,8 +330,7 @@ def logmap(
336330 Returns:
337331 logmap_result: Tensor representing the result of the logarithmic map from `base` to `x` on the manifold.
338332 """
339- if base is None :
340- base = self .mu0
333+ base = self .mu0 if base is None else base
341334 return self .manifold .logmap (x = base , y = x )
342335
343336 def expmap (
@@ -355,8 +348,7 @@ def expmap(
355348 Returns:
356349 expmap_result: Tensor representing the result of the exponential map applied to `u` at the base point.
357350 """
358- if base is None :
359- base = self .mu0
351+ base = self .mu0 if base is None else base
360352 return self .manifold .expmap (x = base , u = u )
361353
362354 def stereographic (self , * points : Float [torch .Tensor , "n_points n_dim" ]) -> tuple [Manifold , ...]:
@@ -633,18 +625,17 @@ def sample(
633625 x: Tensor of sampled points on the manifold
634626 v: Tensor of tangent vectors
635627 """
636- if z_mean is None :
637- z_mean = self .mu0
628+ z_mean = self .mu0 if z_mean is None else z_mean
638629 z_mean = torch .Tensor (z_mean ).reshape (- 1 , self .ambient_dim ).to (self .device )
639630 n = z_mean .shape [0 ]
640631
641- if sigma_factorized is None :
642- sigma_factorized = [torch .stack ([torch .eye (M .dim )] * n ) for M in self .P ]
643- else :
644- sigma_factorized = [
645- torch .Tensor (sigma ).reshape (- 1 , M .dim , M .dim ).to (self .device )
646- for M , sigma in zip (self .P , sigma_factorized , strict = False )
647- ]
632+ sigma_factorized = (
633+ [torch .stack ([torch .eye (M .dim )] * n ) for M in self .P ] if sigma_factorized is None else sigma_factorized
634+ )
635+ sigma_factorized = [
636+ torch .Tensor (sigma ).reshape (- 1 , M .dim , M .dim ).to (self .device )
637+ for M , sigma in zip (self .P , sigma_factorized , strict = False )
638+ ]
648639
649640 assert all (sigma .shape == (n , M .dim , M .dim ) for M , sigma in zip (self .P , sigma_factorized , strict = False )), (
650641 "Sigma matrices must match the dimensions of the manifolds."
@@ -684,12 +675,12 @@ def log_likelihood(
684675 `mu` and covariance `sigma`.
685676 """
686677 n = z .shape [0 ]
687- if mu is None :
688- mu = torch .vstack ([self .mu0 ] * n ).to (self .device )
678+ mu = torch .vstack ([self .mu0 ] * n ).to (self .device ) if mu is None else mu
689679
690- if sigma_factorized is None :
691- sigma_factorized = [torch .stack ([torch .eye (M .dim )] * n ) for M in self .P ]
692- # Note that this factorization assumes block-diagonal covariance matrices
680+ sigma_factorized = (
681+ [torch .stack ([torch .eye (M .dim )] * n ) for M in self .P ] if sigma_factorized is None else sigma_factorized
682+ )
683+ # Note that this factorization assumes block-diagonal covariance matrices
693684
694685 mu_factorized = self .factorize (mu )
695686 z_factorized = self .factorize (z )
@@ -807,10 +798,8 @@ def gaussian_mixture(
807798 torch .manual_seed (seed )
808799
809800 # Deal with clusters
810- if num_clusters is None :
811- num_clusters = num_classes
812- else :
813- assert num_clusters >= num_classes , "Number of clusters must be at least as large as number of classes."
801+ num_clusters = num_clusters or num_classes
802+ assert num_clusters >= num_classes , "Number of clusters must be at least as large as number of classes."
814803
815804 # Adjust covariance matrices for number of dimensions
816805 if adjust_for_dims :
0 commit comments