@@ -262,8 +262,8 @@ def fit( # type: ignore[override]
262262 self ,
263263 X : Float [torch .Tensor , "n_points n_features" ],
264264 D : None = None ,
265- lr : float = 1e-2 ,
266- burn_in_lr : float = 1e-3 ,
265+ lr : float = 1e-3 ,
266+ burn_in_lr : float = 1e-4 ,
267267 curvature_lr : float = 0.0 , # Off by default
268268 burn_in_iterations : int = 1 ,
269269 training_iterations : int = 9 ,
@@ -303,7 +303,7 @@ def fit( # type: ignore[override]
303303
304304 my_tqdm = tqdm (total = (burn_in_iterations + training_iterations ) * len (X ))
305305 opt = torch .optim .Adam (
306- [{"params" : self .parameters (), "lr" : lr * 0.1 }, {"params" : self .pm .parameters (), "lr" : curvature_lr }]
306+ [{"params" : self .parameters (), "lr" : burn_in_lr }, {"params" : self .pm .parameters (), "lr" : 0 }]
307307 )
308308 losses : Dict [str , List [float ]] = {"elbo" : [], "ll" : [], "kl" : []}
309309 for epoch in range (burn_in_iterations + training_iterations ):
@@ -352,13 +352,14 @@ def fit( # type: ignore[override]
352352 return self
353353
354354 def transform (
355- self , X : Float ["n_points n_features" ], D : None = None , batch_size : int = 32
355+ self , X : Float ["n_points n_features" ], D : None = None , batch_size : int = 32 , expmap : bool = True
356356 ) -> Float ["n_points embedding_dim" ]:
357357 """Transform data using the trained VAE. Outputs means of the variational distribution.
358358
359359 Args:
360360 X: Features to embed with VAE.
361361 D: Ignored.
362+ expmap: Whether to use exponential map for embedding.
362363
363364 Returns:
364365 embeddings: Learned embeddings.
@@ -372,7 +373,12 @@ def transform(
372373 embeddings_list = []
373374 for i in range (0 , len (X ), batch_size ):
374375 x_batch = X [i : i + batch_size ]
375- z_mean , _ = self .encode (x_batch )
376+ z_mean_tangent , _ = self .encode (x_batch )
377+ if expmap :
378+ z_mean_ambient = z_mean_tangent @ self .pm .projection_matrix # Adds zeros in the right places
379+ z_mean = self .pm .expmap (u = z_mean_ambient , base = None )
380+ else :
381+ z_mean = z_mean_tangent
376382 embeddings_list .append (z_mean .detach ().cpu ())
377383
378384 embeddings = torch .cat (embeddings_list , dim = 0 )
0 commit comments