|
11 | 11 | from __future__ import annotations |
12 | 12 |
|
13 | 13 | import sys |
14 | | -from typing import Dict, List, Optional, Tuple |
15 | 14 |
|
16 | 15 | import numpy as np |
17 | 16 | import torch |
@@ -74,7 +73,7 @@ def __init__( |
74 | 73 | pm: ProductManifold, |
75 | 74 | encoder: torch.nn.Module, |
76 | 75 | decoder: torch.nn.Module, |
77 | | - random_state: Optional[int] = None, |
| 76 | + random_state: int | None = None, |
78 | 77 | device: str = "cpu", |
79 | 78 | beta: float = 1.0, |
80 | 79 | reconstruction_loss: torch.nn.modules.loss._Loss = torch.nn.MSELoss(reduction="none"), |
@@ -102,7 +101,7 @@ def __init__( |
102 | 101 |
|
103 | 102 | def encode( |
104 | 103 | self, x: Float[torch.Tensor, "batch_size n_features"] |
105 | | - ) -> Tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]: |
| 104 | + ) -> tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]: |
106 | 105 | r"""Encodes input data to obtain latent means and log-variances in the manifold. |
107 | 106 |
|
108 | 107 | This method processes input data through the encoder network to obtain parameters of the approximate posterior |
@@ -140,10 +139,10 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_ambient"]) -> Float[torch. |
140 | 139 | """ |
141 | 140 | return self.decoder(z) |
142 | 141 |
|
143 | | - def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[ |
| 142 | + def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> tuple[ |
144 | 143 | Float[torch.Tensor, "batch_size n_features"], |
145 | 144 | Float[torch.Tensor, "batch_size n_ambient"], |
146 | | - List[Float[torch.Tensor, "n_latent n_latent"]], |
| 145 | + list[Float[torch.Tensor, "batch_size n_latent n_latent"]], |
147 | 146 | ]: |
148 | 147 | r"""Performs the forward pass of the VAE in product manifold space. |
149 | 148 |
|
@@ -181,7 +180,7 @@ def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[ |
181 | 180 | def kl_divergence( |
182 | 181 | self, |
183 | 182 | z_mean: Float[torch.Tensor, "batch_size n_latent"], |
184 | | - sigma_factorized: List[Float[torch.Tensor, "n_latent n_latent"]], |
| 183 | + sigma_factorized: list[Float[torch.Tensor, "batch_size manifold_dim manifold_dim"]], |
185 | 184 | ) -> Float[torch.Tensor, "batch_size"]: |
186 | 185 | r"""Computes the KL divergence between posterior and prior distributions in the manifold. |
187 | 186 |
|
@@ -214,7 +213,7 @@ def kl_divergence( |
214 | 213 |
|
215 | 214 | def elbo( |
216 | 215 | self, x: Float[torch.Tensor, "batch_size n_features"] |
217 | | - ) -> Tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]: |
| 216 | + ) -> tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]: |
218 | 217 | r"""Computes the Evidence Lower Bound (ELBO) for the VAE objective. |
219 | 218 |
|
220 | 219 | The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term |
|
0 commit comments