Skip to content

Commit b3baf4f

Browse files
committed
Start delta-hyperbolicity exploration
1 parent 6a28235 commit b3baf4f

File tree

5 files changed

+3385
-20
lines changed

5 files changed

+3385
-20
lines changed

manify/embedders/_losses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def distortion_loss(
5151
D_true = D_true.flatten()
5252
D_est = D_est.flatten()
5353

54-
# Mask out any infinite or nan values
55-
mask = torch.isfinite(D_true) & ~torch.isnan(D_true)
54+
# Mask out infinite or NaN values, as well as entries where the true distance is zero
55+
mask = torch.isfinite(D_true) & ~torch.isnan(D_true) & (D_true != 0)
5656
D_true = D_true[mask]
5757
D_est = D_est[mask]
5858

manify/embedders/siamese.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,49 @@
1010

1111
from __future__ import annotations
1212

13-
from typing import Optional
13+
import sys
14+
from typing import Dict, List, Optional, Tuple
1415

16+
import numpy as np
1517
import torch
1618
from jaxtyping import Float
1719

1820
from ..manifolds import ProductManifold
1921
from ._base import BaseEmbedder
22+
from ._losses import distortion_loss
2023

24+
# Pick the tqdm progress bar: the notebook version when ipykernel is loaded, the regular one otherwise
25+
if "ipykernel" in sys.modules:
26+
from tqdm.notebook import tqdm
27+
else:
28+
from tqdm import tqdm
2129

22-
class SiameseNetwork(torch.nn.Module):
30+
31+
class SiameseNetwork(BaseEmbedder, torch.nn.Module):
2332
"""Siamese network for embedding data into a product manifold space.
2433
2534
A Siamese network consists of an encoder network that maps input data to a latent representation in a product
2635
manifold, and optionally a decoder network that maps the latent representation back to the original feature space.
2736
2837
Attributes:
29-
pm: The product manifold object defining the embedding space.
30-
encoder: Neural network module that maps input data to the embedding space.
31-
decoder: Optional neural network module for reconstructing input data from embeddings.
32-
reconstruction_loss: Loss function for measuring reconstruction quality.
38+
pm: Product manifold defining the structure of the latent space.
39+
random_state: Random state for reproducibility.
40+
encoder: Neural network that maps inputs to latent embeddings.
41+
decoder: Neural network that reconstructs inputs from latent embeddings.
42+
beta: Weight for the distortion term in the loss function.
43+
device: Device for tensor computations.
44+
reconstruction_loss: Type of reconstruction loss to use.
45+
3346
3447
Args:
35-
pm: Product manifold object defining the target embedding space.
36-
encoder: Neural network module that maps inputs to the embedding space.
37-
decoder: Optional neural network module that maps embeddings back to input space. If None, a no-op identity
38-
module is used.
39-
reconstruction_loss: Type of reconstruction loss to use. Currently only "mse" (mean squared error) is supported.
40-
41-
Raises:
42-
ValueError: If an unsupported reconstruction_loss is specified.
48+
pm: Product manifold defining the structure of the latent space.
49+
encoder: Neural network module that maps inputs to the manifold's intrinsic dimension.
50+
The output dimension should match the intrinsic dimension of the product manifold.
51+
decoder: Neural network module that maps latent representations back to the input space.
52+
random_state: Optional random state for reproducibility.
53+
device: Optional device for tensor computations.
54+
beta: Weight of the distortion term in the loss function.
55+
reconstruction_loss: Type of reconstruction loss to use.
4356
"""
4457

4558
def __init__(

manify/embedders/vae.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class ProductSpaceVAE(BaseEmbedder, torch.nn.Module):
5151
device: Device for tensor computations.
5252
n_samples: Number of samples for Monte Carlo estimation of KL divergence.
5353
reconstruction_loss: Type of reconstruction loss to use.
54+
loss_history_: Dictionary to store the history of loss values during training.
55+
is_fitted_: Boolean flag indicating whether the model has been fitted.
56+
5457
5558
Args:
5659
pm: Product manifold defining the structure of the latent space.

0 commit comments

Comments
 (0)