Skip to content

Commit 58db626

Browse files
feat: add mean and logvar info to variational bottleneck
1 parent 7fabeb3 commit 58db626

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ def forward(
12111211
logvar = torch.clamp(logvar, -30.0, 20.0)
12121212
out = gaussian_sample(mean, logvar)
12131213
loss = kl_loss(mean, logvar) * self.loss_weight
1214-
return (out, dict(loss=loss)) if with_info else out
1214+
return (out, dict(loss=loss, mean=mean, logvar=logvar)) if with_info else out
12151215

12161216

12171217
class AutoEncoder1d(nn.Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.64",
6+
version="0.0.65",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)