@@ -42,7 +42,7 @@ class SiameseNetwork(BaseEmbedder, torch.nn.Module):
4242 beta: Weight for the distortion term in the loss function.
4343 device: Device for tensor computations.
4444 reconstruction_loss: Type of reconstruction loss to use.
45-
45+
4646
4747 Args:
4848 pm: Product manifold defining the structure of the latent space.
@@ -61,10 +61,18 @@ def __init__(
6161 encoder : torch .nn .Module ,
6262 decoder : Optional [torch .nn .Module ] = None ,
6363 reconstruction_loss : str = "mse" ,
64+ beta : float = 1.0 ,
65+ random_state : Optional [int ] = None ,
66+ device : str = "cpu" ,
6467 ):
65- super ().__init__ ()
68+ # Init both base classes
69+ torch .nn .Module .__init__ (self )
70+ BaseEmbedder .__init__ (self , pm = pm , random_state = random_state , device = device )
71+
72+ # Now we assign
6673 self .pm = pm
6774 self .encoder = encoder
75+ self .beta = beta
6876
6977 if decoder is not None :
7078 self .decoder = decoder
@@ -104,3 +112,184 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.T
104112 reconstructed: Tensor containing the reconstructed input data.
105113 """
106114 return self .decoder (z )
115+
def forward(
    self, x1: Float[torch.Tensor, "batch_size n_features"], x2: Float[torch.Tensor, "batch_size n_features"]
) -> Tuple[
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size,"],
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_features"],
]:
    """Embed a pair of inputs, measure their embedding distance, and decode both embeddings.

    Each input is passed through ``self.encode``, multiplied by
    ``self.pm.projection_matrix`` and then handed to ``self.pm.expmap``.
    The two results are compared with ``self.pm.manifold.dist`` and each is
    fed to ``self.decode``.

    Args:
        x1: First input tensor.
        x2: Second input tensor.

    Returns:
        z1: Embedding of the first input.
        z2: Embedding of the second input.
        D_hat: Estimated distance between the two embeddings.
        reconstructed1: Output of the decoder for the first embedding.
        reconstructed2: Output of the decoder for the second embedding.
    """
    product_manifold = self.pm

    def _embed(batch):
        # encoder output -> projection matrix -> exponential map
        projected = self.encode(batch) @ product_manifold.projection_matrix
        return product_manifold.expmap(projected)

    z1 = _embed(x1)
    z2 = _embed(x2)

    # manifold-level dist is used so the result is one distance per pair
    # (per the original author's note: a (batch_size,) vector)
    D_hat = product_manifold.manifold.dist(z1, z2)

    reconstructed1 = self.decode(z1)
    reconstructed2 = self.decode(z2)
    return z1, z2, D_hat, reconstructed1, reconstructed2
144+
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: Float[torch.Tensor, "n_points n_points"],
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "SiameseNetwork":
    """Fit the SiameseNetwork embedder on all unordered pairs of points.

    Every epoch shuffles the full list of index pairs ``(i, j)`` with
    ``i < j`` and walks through it in mini-batches. For each batch the loss is
    ``mse(recon1, x1) + mse(recon2, x2) + self.beta * distortion``, where the
    distortion compares the model's pairwise distances with ``D[i, j]``.

    The first ``burn_in_iterations`` epochs use ``burn_in_lr`` for all
    parameters that do not belong to ``self.pm`` and a learning rate of 0 for
    the parameters of ``self.pm``; afterwards these switch to ``lr`` and
    ``curvature_lr`` respectively.

    Args:
        X: Input data features to encode.
        D: Pairwise distances to emulate.
        lr: Learning rate for the optimizer.
        burn_in_lr: Learning rate during burn-in phase.
        curvature_lr: Learning rate for curvature updates.
        burn_in_iterations: Number of epochs for burn-in phase.
        training_iterations: Number of epochs for training phase.
        loss_window_size: Size of the window for loss averaging.
        logging_interval: Interval (in optimizer steps) for logging progress.
        batch_size: Number of pairs per batch. Must be positive.
        clip_grad: Whether to clip gradients.

    Returns:
        self: Fitted SiameseNetwork instance, with ``loss_history_`` and
            ``is_fitted_`` set.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    # A zero batch size used to fail with ZeroDivisionError and a negative one
    # silently trained nothing while still marking the model as fitted.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    n_samples = len(X)

    # All strictly-upper-triangular index pairs, one pair per row: (n_pairs, 2)
    pairs = torch.triu_indices(n_samples, n_samples, offset=1).T

    # Number of pairs and batches
    n_pairs = len(pairs)
    n_epochs = burn_in_iterations + training_iterations
    n_batches_per_epoch = (n_pairs + batch_size - 1) // batch_size  # Ceiling division
    total_iterations = n_epochs * n_batches_per_epoch

    # Split parameters into "everything except pm" and "pm" so they can have
    # separate learning rates. Membership is decided by object identity and the
    # pm parameter collection is built once rather than once per parameter.
    pm_params = list(self.pm.parameters())
    pm_param_ids = {id(p) for p in pm_params}
    other_params = [p for p in self.parameters() if id(p) not in pm_param_ids]

    opt = torch.optim.Adam(
        [
            {"params": other_params, "lr": burn_in_lr},
            {"params": pm_params, "lr": 0},
        ]
    )
    losses: Dict[str, List[float]] = {"total": [], "reconstruction": [], "distortion": []}

    # The context manager closes the progress bar even if training raises.
    with tqdm(total=total_iterations) as my_tqdm:
        for epoch in range(n_epochs):
            if epoch == burn_in_iterations:
                # Burn-in is over: switch to the main learning rates.
                opt.param_groups[0]["lr"] = lr
                opt.param_groups[1]["lr"] = curvature_lr

            # Shuffle all pairs
            shuffled_pairs = pairs[torch.randperm(n_pairs)]

            for batch_start in range(0, n_pairs, batch_size):
                # Slicing clamps at the end, so the last batch may be shorter.
                batch_pairs = shuffled_pairs[batch_start : batch_start + batch_size]

                # Extract indices for this batch
                batch_indices1 = batch_pairs[:, 0]
                batch_indices2 = batch_pairs[:, 1]

                # Get data for these indices
                X1 = X[batch_indices1]
                X2 = X[batch_indices2]

                # Extract the corresponding distances from D using advanced indexing
                D_batch = D[batch_indices1, batch_indices2]

                # Forward pass
                opt.zero_grad()
                _, _, D_hat, Y1, Y2 = self(X1, X2)
                # NOTE(review): reconstruction is always MSE here; the constructor's
                # `reconstruction_loss` argument is not consulted in this method - confirm intended.
                mse1 = torch.nn.functional.mse_loss(Y1, X1)
                mse2 = torch.nn.functional.mse_loss(Y2, X2)

                # D_hat and D_batch are now 1D tensors of pairwise distances
                distortion = distortion_loss(D_hat, D_batch, pairwise=False)
                L = mse1 + mse2 + self.beta * distortion
                L.backward()

                # Read each scalar once and reuse it for history and display.
                total_value = L.item()
                recon_value = mse1.item() + mse2.item()
                distortion_value = distortion.item()

                # Add to losses
                losses["total"].append(total_value)
                losses["reconstruction"].append(recon_value)
                losses["distortion"].append(distortion_value)

                if clip_grad:
                    # NOTE(review): if self.pm is a registered submodule its parameters are
                    # covered by both calls - kept as in the original; confirm intended.
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                    torch.nn.utils.clip_grad_norm_(pm_params, max_norm=1.0)

                opt.step()

                # TQDM management
                my_tqdm.update(1)
                my_tqdm.set_description(
                    f"L: {total_value:.3e}, recon: {recon_value:.3e}, dist: {distortion_value:.3e}"
                )

                # Logging
                if my_tqdm.n % logging_interval == 0:
                    d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(pm_params)}
                    d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
                    d["recon_avg"] = f"{np.mean(losses['reconstruction'][-loss_window_size:]):.3e}"
                    d["dist_avg"] = f"{np.mean(losses['distortion'][-loss_window_size:]):.3e}"
                    my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self
266+
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points n_latent"]:
    """Transforms input data into manifold embeddings.

    The input is processed in consecutive slices of ``batch_size`` rows. Each
    slice is passed through ``self.encode``; when ``expmap`` is true the result
    is additionally multiplied by ``self.pm.projection_matrix`` and passed
    through ``self.pm.expmap``. The per-slice results are concatenated along
    dimension 0.

    An empty ``X`` is handled by running a single empty slice through the same
    pipeline instead of failing when concatenating zero tensors.

    Args:
        X: Features to embed with SiameseNetwork.
        D: Ignored.
        batch_size: Number of samples per batch.
        expmap: Whether to use exponential map for embedding.

    Returns:
        embeddings: Embeddings produced by the encoder (and, if requested, the exponential map).
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # max(..., 1) guarantees at least one (possibly empty) slice, so torch.cat
    # below never receives an empty list when X has no rows.
    batch_starts = range(0, max(len(X), 1), batch_size)

    chunks = []
    for start in batch_starts:
        encoded = self.encode(X[start : start + batch_size])
        if expmap:
            encoded = self.pm.expmap(encoded @ self.pm.projection_matrix)
        chunks.append(encoded)

    return torch.cat(chunks, dim=0)
0 commit comments