Skip to content

Commit d582296

Browse files
feat: add with_info during forward in autoencoder
1 parent f98c25e commit d582296

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,17 @@ def __init__(
191191
num_groups=resnet_groups,
192192
)
193193

194-
def forward(self, x: Tensor, **kwargs) -> Tensor:
195-
latent = self.encode(x)
194+
def forward( # type: ignore
195+
self, x: Tensor, with_info: bool = False, **kwargs
196+
) -> Union[Tensor, Tuple[Tensor, Any]]:
197+
if with_info:
198+
latent, info = self.encode(x, with_info=True)
199+
else:
200+
latent = self.encode(x)
201+
196202
context = self.to_context(latent)
197-
return self.diffusion(x, context=[context], **kwargs)
203+
loss = self.diffusion(x, context=[context], **kwargs)
204+
return (loss, info) if with_info else loss
198205

199206
def encode(
200207
self, x: Tensor, with_info: bool = False

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.28",
6+
version="0.0.29",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)