Skip to content

Commit ef6d57a

Browse files
committed
Delete dataloaders_old; fix typing in vae.py
1 parent 852053e commit ef6d57a

File tree

2 files changed

+3
-519
lines changed

2 files changed

+3
-519
lines changed

manify/embedders/vae.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import sys
6-
from typing import List, Tuple
6+
from typing import List, Tuple, Optional
77

88
import torch
99
from jaxtyping import Float
@@ -127,7 +127,7 @@ def elbo(
127127
ll = -self.reconstruction_loss(x_reconstructed.view(x.shape[0], -1), x.view(x.shape[0], -1)).sum(dim=1)
128128
return (ll - self.beta * kld).mean(), ll.mean(), kld.mean()
129129

130-
def _grads_ok(self):
130+
def _grads_ok(self) -> bool:
131131
out = True
132132
for name, param in self.named_parameters():
133133
if param.grad is not None:
@@ -145,7 +145,7 @@ def fit(
145145
burn_in_epochs: int = 100,
146146
epochs: int = 1900,
147147
batch_size: int = 32,
148-
seed: int = None,
148+
seed: Optional[int] = None,
149149
lr: float = 1e-3,
150150
curvature_lr: float = 1e-4,
151151
clip_grad: bool = True,

0 commit comments

Comments
 (0)