Skip to content

Commit 8b048b3

Browse files
committed
Add siamese neural net; qiita dataset
1 parent 9adac2c commit 8b048b3

File tree

10 files changed

+560
-3206
lines changed

10 files changed

+560
-3206
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,26 @@ jobs:
4141
# Code quality checks
4242
- name: Check code formatting with Black
4343
run: black --check manify/ --line-length 120
44-
continue-on-error: true
4544

4645
- name: Check import ordering with isort
4746
run: isort --check-only --profile black manify/ --line-width 120
48-
continue-on-error: true
4947

5048
- name: Run pylint
5149
run: pylint manify/
52-
continue-on-error: true
5350

5451
# Type checking
5552
- name: Check type annotations with MyPy
5653
run: mypy manify/
57-
continue-on-error: true
5854

5955
# Unit testing
6056
- name: Run unit tests & collect coverage
6157
run: pytest tests --cov=manify --cov-report=xml:coverage.xml
6258

59+
# Check docstrings are in Google style
60+
- name: Check docstrings are in Google style
61+
run: pydocstyle manify/ --convention=google
62+
continue-on-error: true
63+
6364
# Code coverage
6465
- name: Upload coverage to Codecov
6566
uses: codecov/codecov-action@v5

manify/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Manify: A Python Library for Learning Non-Euclidean Representations."""
22

33
from manify.curvature_estimation import (
4-
sampled_delta_hyperbolicity,
54
delta_hyperbolicity,
6-
sectional_curvature,
75
greedy_signature_selection,
6+
sampled_delta_hyperbolicity,
7+
sectional_curvature,
88
)
99
from manify.embedders import CoordinateLearning, ProductSpaceVAE, SiameseNetwork
1010
from manify.manifolds import Manifold, ProductManifold
11-
from manify.predictors import ProductSpaceDT, ProductSpaceRF, KappaGCN, ProductSpacePerceptron, ProductSpaceSVM
11+
from manify.predictors import KappaGCN, ProductSpaceDT, ProductSpacePerceptron, ProductSpaceRF, ProductSpaceSVM
1212

1313
# import manify.utils
1414

manify/curvature_estimation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
* `sectional_curvature`: Estimates the sectional curvature of a graph from its distance matrix.
88
"""
99

10-
from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, delta_hyperbolicity
10+
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, sampled_delta_hyperbolicity
1111
from manify.curvature_estimation.greedy_method import greedy_signature_selection
1212
from manify.curvature_estimation.sectional_curvature import sectional_curvature
1313

manify/embedders/siamese.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class SiameseNetwork(BaseEmbedder, torch.nn.Module):
4242
beta: Weight for the distortion term in the loss function.
4343
device: Device for tensor computations.
4444
reconstruction_loss: Type of reconstruction loss to use.
45-
45+
4646
4747
Args:
4848
pm: Product manifold defining the structure of the latent space.
@@ -61,10 +61,18 @@ def __init__(
6161
encoder: torch.nn.Module,
6262
decoder: Optional[torch.nn.Module] = None,
6363
reconstruction_loss: str = "mse",
64+
beta: float = 1.0,
65+
random_state: Optional[int] = None,
66+
device: str = "cpu",
6467
):
65-
super().__init__()
68+
# Init both base classes
69+
torch.nn.Module.__init__(self)
70+
BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)
71+
72+
# Now we assign
6673
self.pm = pm
6774
self.encoder = encoder
75+
self.beta = beta
6876

6977
if decoder is not None:
7078
self.decoder = decoder
@@ -104,3 +112,184 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.T
104112
reconstructed: Tensor containing the reconstructed input data.
105113
"""
106114
return self.decoder(z)
115+
116+
def forward(
    self, x1: Float[torch.Tensor, "batch_size n_features"], x2: Float[torch.Tensor, "batch_size n_features"]
) -> Tuple[
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size"],
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_features"],
]:
    """Given two batches of points, return their encodings, embedding distances, and reconstructions.

    Each input is passed through ``self.encode``, multiplied by ``self.pm.projection_matrix`` and mapped with
    ``self.pm.expmap``. The distance between the paired embeddings is computed with ``self.pm.manifold.dist``,
    and each embedding is passed through ``self.decode``.

    Args:
        x1: First input tensor.
        x2: Second input tensor.

    Returns:
        z1: Encoded representation of the first input.
        z2: Encoded representation of the second input.
        D_hat: Estimated distance between the two embeddings.
        reconstructed1: Reconstructed input from the first embedding.
        reconstructed2: Reconstructed input from the second embedding.
    """
    z1 = self.pm.expmap(self.encode(x1) @ self.pm.projection_matrix)
    z2 = self.pm.expmap(self.encode(x2) @ self.pm.projection_matrix)

    # Use manifold dist (rather than a pairwise distance matrix) to get a (batch_size,) vector of dists
    D_hat = self.pm.manifold.dist(z1, z2)

    reconstructed1 = self.decode(z1)
    reconstructed2 = self.decode(z2)
    return z1, z2, D_hat, reconstructed1, reconstructed2
144+
145+
def fit( # type: ignore[override]
146+
self,
147+
X: Float[torch.Tensor, "n_points n_features"],
148+
D: Float[torch.Tensor, "n_points n_points"],
149+
lr: float = 1e-3,
150+
burn_in_lr: float = 1e-4,
151+
curvature_lr: float = 0.0, # Off by default
152+
burn_in_iterations: int = 1,
153+
training_iterations: int = 9,
154+
loss_window_size: int = 100,
155+
logging_interval: int = 10,
156+
batch_size: int = 32,
157+
clip_grad: bool = True,
158+
) -> "SiameseNetwork":
159+
"""Fit the SiameseNetwork embedder.
160+
161+
Args:
162+
X: Input data features to encode.
163+
D: Pairwise distances to emulate.
164+
lr: Learning rate for the optimizer.
165+
burn_in_lr: Learning rate during burn-in phase.
166+
curvature_lr: Learning rate for curvature updates.
167+
burn_in_iterations: Number of iterations for burn-in phase.
168+
training_iterations: Number of iterations for training phase.
169+
loss_window_size: Size of the window for loss averaging.
170+
logging_interval: Interval for logging progress.
171+
batch_size: Number of samples per batch.
172+
clip_grad: Whether to clip gradients.
173+
174+
Returns:
175+
self: Fitted SiameseNetwork instance.
176+
"""
177+
if self.random_state is not None:
178+
torch.manual_seed(self.random_state)
179+
180+
n_samples = len(X)
181+
182+
# Generate all upper triangular pairs using torch
183+
indices = torch.triu_indices(n_samples, n_samples, offset=1)
184+
pairs = torch.hstack([indices]).T # (n_pairs, 2)
185+
186+
# Number of pairs and batches
187+
n_pairs = len(pairs)
188+
n_batches_per_epoch = (n_pairs + batch_size - 1) // batch_size # Ceiling division
189+
total_iterations = (burn_in_iterations + training_iterations) * n_batches_per_epoch
190+
191+
my_tqdm = tqdm(total=total_iterations)
192+
193+
opt = torch.optim.Adam(
194+
[
195+
{"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
196+
{"params": self.pm.parameters(), "lr": 0},
197+
]
198+
)
199+
losses: Dict[str, List[float]] = {"total": [], "reconstruction": [], "distortion": []}
200+
201+
for epoch in range(burn_in_iterations + training_iterations):
202+
if epoch == burn_in_iterations:
203+
opt.param_groups[0]["lr"] = lr
204+
opt.param_groups[1]["lr"] = curvature_lr
205+
206+
# Shuffle all pairs
207+
shuffle_idx = torch.randperm(n_pairs)
208+
shuffled_pairs = pairs[shuffle_idx]
209+
210+
for batch_start in range(0, n_pairs, batch_size):
211+
batch_end = min(batch_start + batch_size, n_pairs)
212+
batch_pairs = shuffled_pairs[batch_start:batch_end]
213+
214+
# Extract indices for this batch
215+
batch_indices1 = batch_pairs[:, 0]
216+
batch_indices2 = batch_pairs[:, 1]
217+
218+
# Get data for these indices
219+
X1 = X[batch_indices1]
220+
X2 = X[batch_indices2]
221+
222+
# Extract the corresponding distances from D using advanced indexing
223+
D_batch = D[batch_indices1, batch_indices2]
224+
225+
# Forward pass
226+
opt.zero_grad()
227+
_, _, D_hat, Y1, Y2 = self(X1, X2)
228+
mse1 = torch.nn.functional.mse_loss(Y1, X1)
229+
mse2 = torch.nn.functional.mse_loss(Y2, X2)
230+
231+
# D_hat and D_batch are now 1D tensors of pairwise distances
232+
distortion = distortion_loss(D_hat, D_batch, pairwise=False)
233+
L = mse1 + mse2 + self.beta * distortion
234+
L.backward()
235+
236+
# Add to losses
237+
losses["total"].append(L.item())
238+
losses["reconstruction"].append(mse1.item() + mse2.item())
239+
losses["distortion"].append(distortion.item())
240+
241+
if clip_grad:
242+
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
243+
torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)
244+
245+
opt.step()
246+
247+
# TQDM management
248+
my_tqdm.update(1)
249+
my_tqdm.set_description(
250+
f"L: {L.item():.3e}, recon: {mse1.item() + mse2.item():.3e}, dist: {distortion.item():.3e}"
251+
)
252+
253+
# Logging
254+
if my_tqdm.n % logging_interval == 0:
255+
d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
256+
d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
257+
d["recon_avg"] = f"{np.mean(losses['reconstruction'][-loss_window_size:]):.3e}"
258+
d["dist_avg"] = f"{np.mean(losses['distortion'][-loss_window_size:]):.3e}"
259+
my_tqdm.set_postfix(d)
260+
261+
# Final maintenance: update attributes
262+
self.loss_history_ = losses
263+
self.is_fitted_ = True
264+
265+
return self
266+
267+
def transform(
268+
self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
269+
) -> Float[torch.Tensor, "n_points n_latent"]:
270+
"""Transforms input data into manifold embeddings.
271+
272+
Args:
273+
X: Features to embed with SiameseNetwork.
274+
D: Ignored.
275+
batch_size: Number of samples per batch.
276+
expmap: Whether to use exponential map for embedding.
277+
278+
Returns:
279+
embeddings: Embeddings produced by forward pass of trained SiameseNetwork model.
280+
"""
281+
# Set random state
282+
if self.random_state is not None:
283+
torch.manual_seed(self.random_state)
284+
285+
# Save the embeddings
286+
embeddings_list = []
287+
for i in range(0, len(X), batch_size):
288+
batch = X[i : i + batch_size]
289+
embeddings = self.encode(batch)
290+
if expmap:
291+
embeddings = self.pm.expmap(embeddings @ self.pm.projection_matrix)
292+
embeddings_list.append(embeddings)
293+
embeddings = torch.cat(embeddings_list, dim=0)
294+
295+
return embeddings

manify/embedders/vae.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,10 @@ def fit( # type: ignore[override]
306306

307307
my_tqdm = tqdm(total=(burn_in_iterations + training_iterations) * len(X))
308308
opt = torch.optim.Adam(
309-
[{"params": self.parameters(), "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
309+
[
310+
{"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
311+
{"params": self.pm.parameters(), "lr": 0},
312+
]
310313
)
311314
losses: Dict[str, List[float]] = {"elbo": [], "ll": [], "kl": []}
312315
for epoch in range(burn_in_iterations + training_iterations):
@@ -319,12 +322,12 @@ def fit( # type: ignore[override]
319322
X_batch = X[i : i + batch_size]
320323
elbo, ll, kl = self.elbo(X_batch)
321324
L = -elbo
325+
L.backward()
322326

323327
# Add to losses
324328
losses["elbo"].append(elbo.item())
325329
losses["ll"].append(ll.item())
326330
losses["kl"].append(kl.item())
327-
L.backward()
328331

329332
if clip_grad:
330333
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
@@ -362,6 +365,7 @@ def transform(
362365
Args:
363366
X: Features to embed with VAE.
364367
D: Ignored.
368+
batch_size: Number of samples per batch.
365369
expmap: Whether to use exponential map for embedding.
366370
367371
Returns:

manify/utils/dataloaders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
| neuron_33 | classification | ❌ | ✅ | ✅ | ❌ | [Allen Brain Atlas](https://celltypes.brain-map.org/experiment/electrophysiology/623474400) |
2929
| neuron_46 | classification | ❌ | ✅ | ✅ | ❌ | [Allen Brain Atlas](https://celltypes.brain-map.org/experiment/electrophysiology/623474400) |
3030
| traffic | regression | ❌ | ✅ | ✅ | ❌ | [Kaggle: Traffic Prediction Dataset](https://www.kaggle.com/datasets/fedesoriano/traffic-prediction-dataset) |
31+
| qiita | none | ✅ | ✅ | ❌ | ❌ | [NeuroSEED Git Repo](https://github.com/gcorso/NeuroSEED) |
3132
"""
3233

3334
from __future__ import annotations

0 commit comments

Comments
 (0)