Skip to content

Commit 948ffa4

Browse files
committed
Add n_samples to sampling; add tutorial
1 parent e0f26ad commit 948ffa4

File tree

10 files changed

+1423
-44
lines changed

10 files changed

+1423
-44
lines changed

.github/workflows/test.yml

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,6 @@ jobs:
2929
python-version: ${{ matrix.python-version }}
3030
cache: "pip"
3131

32-
# Cache HuggingFace - this saves time running tests/test_utils.py on subsequent runs
33-
- name: Cache HuggingFace datasets
34-
uses: actions/cache@v4
35-
with:
36-
path: ~/.cache/huggingface
37-
key: ${{ runner.os }}-huggingface-${{ hashFiles('tests/test_*.py') }}
38-
restore-keys: |
39-
${{ runner.os }}-huggingface-
40-
4132
- name: Install dependencies
4233
run: |
4334
python -m pip install --upgrade pip
@@ -67,8 +58,8 @@ jobs:
6758
run: mypy manify/
6859

6960
# Unit testing
70-
- name: Run unit tests & collect coverage
71-
run: pytest tests --cov --cov-report=xml
61+
- name: Run unit tests & collect coverage (except dataloaders)
62+
run: pytest tests --cov --cov-report=xml -k "not test_dataloaders"
7263

7364

7465
# Code coverage
@@ -79,3 +70,46 @@ jobs:
7970
fail_ci_if_error: false
8071
verbose: true
8172
flags: unittests
73+
name: python-${{ matrix.python-version }}
74+
75+
# Dataloaders run in parallel, for speed
76+
test-dataloaders:
77+
runs-on: ubuntu-latest
78+
79+
steps:
80+
- name: Check out code
81+
uses: actions/checkout@v4
82+
83+
- name: Set up Python 3.11
84+
uses: actions/setup-python@v5
85+
with:
86+
python-version: "3.11"
87+
cache: "pip"
88+
89+
- name: Cache HuggingFace datasets
90+
uses: actions/cache@v4
91+
with:
92+
path: ~/.cache/huggingface
93+
key: ${{ runner.os }}-huggingface-dataloaders-v1
94+
restore-keys: |
95+
${{ runner.os }}-huggingface-dataloaders-
96+
${{ runner.os }}-huggingface-
97+
98+
- name: Install dependencies
99+
run: |
100+
python -m pip install --upgrade pip
101+
pip install -e ".[dev]"
102+
103+
- name: Run dataloader tests
104+
run: pytest tests/test_utils.py::test_dataloaders -v --cov=manify/dataloaders --cov-report=xml
105+
106+
# Upload dataloader coverage separately
107+
- name: Upload dataloader coverage to Codecov
108+
uses: codecov/codecov-action@v5
109+
with:
110+
token: ${{ secrets.CODECOV_TOKEN }}
111+
fail_ci_if_error: false
112+
verbose: true
113+
flags: dataloaders
114+
name: dataloaders
115+

manify/clustering/fuzzy_kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
115115
# If we provide self.manifold.mu0 repeated n_clusters times,
116116
# it samples n_clusters points, each around mu0.
117117
means_for_sampling_centers = self.manifold.mu0.repeat(self.n_clusters, 1)
118-
centers, _ = self.manifold.sample(z_mean=means_for_sampling_centers)
118+
centers = self.manifold.sample(z_mean=means_for_sampling_centers)
119119

120120
# IMPORTANT: Use self.manifold.manifold for ManifoldParameter,
121121
# as self.manifold is our wrapper and self.manifold.manifold is the geoopt object.

manify/curvature_estimation/_pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def distortion_pipeline(
5353
return float(distortion_loss(new_dists, dists_rescaled).item())
5454

5555

56-
def classifier_pipeline(
56+
def predictor_pipeline(
5757
pm: ProductManifold,
5858
dists: Float[torch.Tensor, "n_nodes n_nodes"],
5959
labels: Float[torch.Tensor, "n_nodes"],

manify/embedders/coordinate_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def fit( # type: ignore[override]
121121
n = D.shape[0]
122122
covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
123123
means = torch.vstack([self.pm.mu0] * n).to(self.device)
124-
X_embed, _ = self.pm.sample(z_mean=means, sigma_factorized=covs)
124+
X_embed = self.pm.sample(z_mean=means, sigma_factorized=covs)
125125
D = D.to(self.device)
126126

127127
# Get train and test indices set up

manify/embedders/vae.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def forward(
180180
sigmas = [torch.diag_embed(torch.exp(z_logvar) + 1e-8) for z_logvar in sigma_factorized]
181181

182182
# Sample and decode
183-
z, _ = self.pm.sample(z_means, sigmas)
183+
z = self.pm.sample(z_mean=z_means, sigma_factorized=sigmas)
184184
x_reconstructed = self.decode(z)
185185
return x_reconstructed, z_means, sigmas
186186

@@ -213,7 +213,8 @@ def kl_divergence(
213213
sigmas_factorized_interleaved = [
214214
torch.repeat_interleave(sigma, self.n_samples, dim=0) for sigma in sigma_factorized
215215
]
216-
z_samples, _ = self.pm.sample(means, sigmas_factorized_interleaved)
216+
# We want to use n_samples = 1 here, since we'll need to pass the interleaved means/sigmas to the log-likelihood
217+
z_samples = self.pm.sample(z_mean=means, sigma_factorized=sigmas_factorized_interleaved)
217218
log_qz = self.pm.log_likelihood(z_samples, means, sigmas_factorized_interleaved)
218219
log_pz = self.pm.log_likelihood(z_samples)
219220
return (log_qz - log_pz).view(-1, self.n_samples).mean(dim=1)

manify/manifolds.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,25 @@ def _to_tangent_plane_mu0(
211211

212212
def sample(
213213
self,
214-
z_mean: Float[torch.Tensor, "n_points n_ambient_dim"] | None = None,
214+
n_samples: int = 1,
215+
z_mean: Float[torch.Tensor, "n_points n_ambient_dim"] | Float[torch.Tensor, "n_ambient_dim"] | None = None,
215216
sigma: Float[torch.Tensor, "n_points n_dim n_dim"] | None = None,
216-
) -> tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]]:
217+
return_tangent: bool = False,
218+
) -> (
219+
tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]]
220+
| Float[torch.Tensor, "n_points n_ambient_dim"]
221+
):
217222
"""Sample points from the variational distribution on the manifold.
218223
219224
Args:
225+
n_samples: Number of points to sample.
220226
z_mean: Tensor representing the mean of the sample distribution.
221227
sigma: Optional tensor representing the covariance matrix. If None, defaults to an identity matrix.
228+
return_tangent: Whether to return the tangent vectors along with the sampled points.
222229
223230
Returns:
224231
x: Tensor of sampled points on the manifold
225-
v: Tensor of tangent vectors
232+
v: Tensor of tangent vectors (if `return_tangent` is True).
226233
"""
227234
z_mean = self.mu0 if z_mean is None else z_mean
228235
z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
@@ -237,6 +244,10 @@ def sample(
237244
assert torch.allclose(sigma, sigma.transpose(-1, -2)), "Covariance matrix must be symmetric"
238245
assert z_mean.shape[-1] == self.ambient_dim, f"Expected z_mean shape {self.ambient_dim}, got {z_mean.shape[-1]}"
239246

247+
# Adjust for n_points:
248+
z_mean = torch.repeat_interleave(z_mean, n_samples, dim=0)
249+
sigma = torch.repeat_interleave(sigma, n_samples, dim=0)
250+
240251
# Sample initial vector from N(0, sigma)
241252
N = torch.distributions.MultivariateNormal(
242253
loc=torch.zeros((n, self.dim), device=self.device), covariance_matrix=sigma
@@ -260,8 +271,7 @@ def sample(
260271
# Exp map onto the manifold
261272
x = self.manifold.expmap(x=z_mean, u=z)
262273

263-
# Different samples and tangent vectors
264-
return x, v
274+
return (x, v) if return_tangent else x
265275

266276
def log_likelihood(
267277
self,
@@ -611,19 +621,26 @@ def factorize(
611621

612622
def sample(
613623
self,
624+
n_samples: int = 1,
614625
z_mean: Float[torch.Tensor, "n_points n_ambient_dim"] | None = None,
615-
sigma_factorized: list[Float[torch.Tensor, "n_points ..."]] | None = None, # TODO: fix ... annotations
616-
) -> tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points total_intrinsic_dim"]]:
626+
sigma_factorized: list[Float[torch.Tensor, "n_points ..."]] | None = None,
627+
return_tangent: bool = False,
628+
) -> (
629+
tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points total_intrinsic_dim"]]
630+
| Float[torch.Tensor, "n_points n_ambient_dim"]
631+
):
617632
"""Sample from the variational distribution.
618633
619634
Args:
635+
n_samples: Number of points to sample.
620636
z_mean: Tensor representing the mean of the sample distribution. If None, defaults to the origin `self.mu0`.
621637
sigma_factorized: List of tensors representing factorized covariance matrices for each manifold. If None,
622638
defaults to a list of identity matrices for each manifold.
639+
return_tangent: Whether to return the tangent vectors along with the sampled points.
623640
624641
Returns:
625642
x: Tensor of sampled points on the manifold
626-
v: Tensor of tangent vectors
643+
v: Tensor of tangent vectors (if `return_tangent` is True).
627644
"""
628645
z_mean = self.mu0 if z_mean is None else z_mean
629646
z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
@@ -637,24 +654,28 @@ def sample(
637654
for M, sigma in zip(self.P, sigma_factorized, strict=False)
638655
]
639656

640-
assert all(sigma.shape == (n, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized, strict=False)), (
641-
"Sigma matrices must match the dimensions of the manifolds."
642-
)
643-
assert z_mean.shape[-1] == self.ambient_dim, (
657+
# Adjust for n_points:
658+
z_mean = torch.repeat_interleave(z_mean, n_samples, dim=0)
659+
sigma_factorized = [torch.repeat_interleave(sigma, n_samples, dim=0) for sigma in sigma_factorized]
660+
661+
assert all(
662+
sigma.shape == (n * n_samples, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized, strict=False)
663+
), "Sigma matrices must match the dimensions of the manifolds."
664+
assert z_mean.shape == (n * n_samples, self.ambient_dim), (
644665
"z_mean must have the same ambient dimension as the product manifold."
645666
)
646667

647668
# Sample initial vector from N(0, sigma)
648669
samples = [
649-
M.sample(z_M, sigma_M)
670+
M.sample(1, z_M, sigma_M, return_tangent=True)
650671
for M, z_M, sigma_M in zip(self.P, self.factorize(z_mean), sigma_factorized, strict=False)
651672
]
652673

653674
x = torch.cat([s[0] for s in samples], dim=1)
654675
v = torch.cat([s[1] for s in samples], dim=1)
655676

656677
# Different samples and tangent vectors
657-
return x, v
678+
return (x, v) if return_tangent else x
658679

659680
def log_likelihood(
660681
self,
@@ -807,15 +828,13 @@ def gaussian_mixture(
807828
cov_scale_means /= self.dim
808829

809830
# Generate cluster means
810-
cluster_means, _ = self.sample(
811-
z_mean=torch.vstack([self.mu0] * num_clusters),
812-
sigma_factorized=[torch.stack([torch.eye(M.dim)] * num_clusters) * cov_scale_means for M in self.P],
813-
)
831+
cluster_means = self.sample(num_clusters, sigma_factorized=[torch.eye(M.dim) * cov_scale_means for M in self.P])
814832
assert cluster_means.shape == (num_clusters, self.ambient_dim), "Cluster means shape mismatch."
815833

816834
# Generate class assignments
817835
cluster_probs = torch.rand(num_clusters)
818836
cluster_probs /= cluster_probs.sum()
837+
819838
# Draw cluster assignments: ensure at least 2 points per cluster. This is to ensure splits can always happen.
820839
cluster_assignments = torch.multinomial(input=cluster_probs, num_samples=num_points, replacement=True)
821840
while (cluster_assignments.bincount() < 2).any():
@@ -835,7 +854,7 @@ def gaussian_mixture(
835854
sample_means = torch.stack([cluster_means[c] for c in cluster_assignments])
836855
assert sample_means.shape == (num_points, self.ambient_dim), "Sample means shape mismatch."
837856
sample_covs = [torch.stack([cov_matrix[c] for c in cluster_assignments]) for cov_matrix in cov_matrices]
838-
samples, tangent_vals = self.sample(z_mean=sample_means, sigma_factorized=sample_covs)
857+
samples, tangent_vals = self.sample(z_mean=sample_means, sigma_factorized=sample_covs, return_tangent=True)
839858
assert samples.shape == (num_points, self.ambient_dim), "Sample shape mismatch."
840859

841860
# Map clusters to classes

manify/predictors/kappa_gcn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def fit(
214214
if use_tqdm:
215215
my_tqdm = tqdm(total=epochs, desc=tqdm_prefix)
216216

217+
losses = []
217218
for i in range(epochs):
218219
opt.zero_grad()
219220
if riemannian_params:
@@ -234,12 +235,13 @@ def fit(
234235
if torch.isnan(loss):
235236
print("Loss is NaN, stopping training.")
236237
break
238+
losses.append(loss.item())
237239

238240
if use_tqdm:
239241
my_tqdm.close()
240242

241243
self.is_fitted_ = True
242-
self.loss_history_["train"] = [loss.item()]
244+
self.loss_history_["train"] = losses
243245
return self
244246

245247
def predict_proba(

0 commit comments

Comments
 (0)