@@ -211,18 +211,25 @@ def _to_tangent_plane_mu0(
211211
212212 def sample (
213213 self ,
214- z_mean : Float [torch .Tensor , "n_points n_ambient_dim" ] | None = None ,
214+ n_samples : int = 1 ,
215+ z_mean : Float [torch .Tensor , "n_points n_ambient_dim" ] | Float [torch .Tensor , "n_ambient_dim" ] | None = None ,
215216 sigma : Float [torch .Tensor , "n_points n_dim n_dim" ] | None = None ,
216- ) -> tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]]:
217+ return_tangent : bool = False ,
218+ ) -> (
219+ tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points n_dim" ]]
220+ | Float [torch .Tensor , "n_points n_ambient_dim" ]
221+ ):
217222 """Sample points from the variational distribution on the manifold.
218223
219224 Args:
225+ n_samples: Number of points to sample.
220226 z_mean: Tensor representing the mean of the sample distribution.
221227 sigma: Optional tensor representing the covariance matrix. If None, defaults to an identity matrix.
228+ return_tangent: Whether to return the tangent vectors along with the sampled points.
222229
223230 Returns:
224231 x: Tensor of sampled points on the manifold
225- v: Tensor of tangent vectors
232+ v: Tensor of tangent vectors (if `return_tangent` is True).
226233 """
227234 z_mean = self .mu0 if z_mean is None else z_mean
228235 z_mean = torch .Tensor (z_mean ).reshape (- 1 , self .ambient_dim ).to (self .device )
@@ -237,6 +244,10 @@ def sample(
237244 assert torch .allclose (sigma , sigma .transpose (- 1 , - 2 )), "Covariance matrix must be symmetric"
238245 assert z_mean .shape [- 1 ] == self .ambient_dim , f"Expected z_mean shape { self .ambient_dim } , got { z_mean .shape [- 1 ]} "
239246
247+ # Adjust for n_samples: repeat each mean and covariance n_samples times along dim 0
248+ z_mean = torch .repeat_interleave (z_mean , n_samples , dim = 0 )
249+ sigma = torch .repeat_interleave (sigma , n_samples , dim = 0 )
250+
240251 # Sample initial vector from N(0, sigma)
241252 N = torch .distributions .MultivariateNormal (
242253 loc = torch .zeros ((n , self .dim ), device = self .device ), covariance_matrix = sigma
@@ -260,8 +271,7 @@ def sample(
260271 # Exp map onto the manifold
261272 x = self .manifold .expmap (x = z_mean , u = z )
262273
263- # Different samples and tangent vectors
264- return x , v
274+ return (x , v ) if return_tangent else x
265275
266276 def log_likelihood (
267277 self ,
@@ -611,19 +621,26 @@ def factorize(
611621
612622 def sample (
613623 self ,
624+ n_samples : int = 1 ,
614625 z_mean : Float [torch .Tensor , "n_points n_ambient_dim" ] | None = None ,
615- sigma_factorized : list [Float [torch .Tensor , "n_points ..." ]] | None = None , # TODO: fix ... annotations
616- ) -> tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points total_intrinsic_dim" ]]:
626+ sigma_factorized : list [Float [torch .Tensor , "n_points ..." ]] | None = None ,
627+ return_tangent : bool = False ,
628+ ) -> (
629+ tuple [Float [torch .Tensor , "n_points n_ambient_dim" ], Float [torch .Tensor , "n_points total_intrinsic_dim" ]]
630+ | Float [torch .Tensor , "n_points n_ambient_dim" ]
631+ ):
617632 """Sample from the variational distribution.
618633
619634 Args:
635+ n_samples: Number of points to sample.
620636 z_mean: Tensor representing the mean of the sample distribution. If None, defaults to the origin `self.mu0`.
621637 sigma_factorized: List of tensors representing factorized covariance matrices for each manifold. If None,
622638 defaults to a list of identity matrices for each manifold.
639+ return_tangent: Whether to return the tangent vectors along with the sampled points.
623640
624641 Returns:
625642 x: Tensor of sampled points on the manifold
626- v: Tensor of tangent vectors
643+ v: Tensor of tangent vectors (if `return_tangent` is True).
627644 """
628645 z_mean = self .mu0 if z_mean is None else z_mean
629646 z_mean = torch .Tensor (z_mean ).reshape (- 1 , self .ambient_dim ).to (self .device )
@@ -637,24 +654,28 @@ def sample(
637654 for M , sigma in zip (self .P , sigma_factorized , strict = False )
638655 ]
639656
640- assert all (sigma .shape == (n , M .dim , M .dim ) for M , sigma in zip (self .P , sigma_factorized , strict = False )), (
641- "Sigma matrices must match the dimensions of the manifolds."
642- )
643- assert z_mean .shape [- 1 ] == self .ambient_dim , (
657+ # Adjust for n_samples: repeat each mean and factorized covariance n_samples times along dim 0
658+ z_mean = torch .repeat_interleave (z_mean , n_samples , dim = 0 )
659+ sigma_factorized = [torch .repeat_interleave (sigma , n_samples , dim = 0 ) for sigma in sigma_factorized ]
660+
661+ assert all (
662+ sigma .shape == (n * n_samples , M .dim , M .dim ) for M , sigma in zip (self .P , sigma_factorized , strict = False )
663+ ), "Sigma matrices must match the dimensions of the manifolds."
664+ assert z_mean .shape == (n * n_samples , self .ambient_dim ), (
644665 "z_mean must have the same ambient dimension as the product manifold."
645666 )
646667
647668 # Sample initial vector from N(0, sigma)
648669 samples = [
649- M .sample (z_M , sigma_M )
670+ M .sample (1 , z_M , sigma_M , return_tangent = True )
650671 for M , z_M , sigma_M in zip (self .P , self .factorize (z_mean ), sigma_factorized , strict = False )
651672 ]
652673
653674 x = torch .cat ([s [0 ] for s in samples ], dim = 1 )
654675 v = torch .cat ([s [1 ] for s in samples ], dim = 1 )
655676
656677 # Different samples and tangent vectors
657- return x , v
678+ return ( x , v ) if return_tangent else x
658679
659680 def log_likelihood (
660681 self ,
@@ -807,15 +828,13 @@ def gaussian_mixture(
807828 cov_scale_means /= self .dim
808829
809830 # Generate cluster means
810- cluster_means , _ = self .sample (
811- z_mean = torch .vstack ([self .mu0 ] * num_clusters ),
812- sigma_factorized = [torch .stack ([torch .eye (M .dim )] * num_clusters ) * cov_scale_means for M in self .P ],
813- )
831+ cluster_means = self .sample (num_clusters , sigma_factorized = [torch .eye (M .dim ) * cov_scale_means for M in self .P ])
814832 assert cluster_means .shape == (num_clusters , self .ambient_dim ), "Cluster means shape mismatch."
815833
816834 # Generate class assignments
817835 cluster_probs = torch .rand (num_clusters )
818836 cluster_probs /= cluster_probs .sum ()
837+
819838 # Draw cluster assignments: ensure at least 2 points per cluster. This is to ensure splits can always happen.
820839 cluster_assignments = torch .multinomial (input = cluster_probs , num_samples = num_points , replacement = True )
821840 while (cluster_assignments .bincount () < 2 ).any ():
@@ -835,7 +854,7 @@ def gaussian_mixture(
835854 sample_means = torch .stack ([cluster_means [c ] for c in cluster_assignments ])
836855 assert sample_means .shape == (num_points , self .ambient_dim ), "Sample means shape mismatch."
837856 sample_covs = [torch .stack ([cov_matrix [c ] for c in cluster_assignments ]) for cov_matrix in cov_matrices ]
838- samples , tangent_vals = self .sample (z_mean = sample_means , sigma_factorized = sample_covs )
857+ samples , tangent_vals = self .sample (z_mean = sample_means , sigma_factorized = sample_covs , return_tangent = True )
839858 assert samples .shape == (num_points , self .ambient_dim ), "Sample shape mismatch."
840859
841860 # Map clusters to classes
0 commit comments