Skip to content

Commit c458fc2

Browse files
feat: add magnitude, phase info
1 parent fc1fb9b commit c458fc2

File tree

2 files changed

+8 -3 lines changed

2 files changed

+8 -3 lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1777,9 +1777,14 @@ def encode(
17771777
log_magnitude = rearrange(torch.log(magnitude), "b c f t -> b (c f) t")
17781778
return super().encode(log_magnitude, with_info)
17791779

1780-
def decode(self, z: Tensor) -> Tensor:
1780+
def decode( # type: ignore
1781+
self, z: Tensor, with_info: bool = False
1782+
) -> Union[Tensor, Tuple[Tensor, Any]]:
17811783
f = self.frequency_channels
17821784
stft = super().decode(z)
17831785
stft = rearrange(stft, "b (c f i) t -> b (c i) f t", i=2, f=f)
17841786
log_magnitude, phase = stft.chunk(chunks=2, dim=1)
1785-
return self.stft.decode(magnitude=torch.exp(log_magnitude), phase=phase)
1787+
magnitude = torch.exp(log_magnitude)
1788+
wave = self.stft.decode(magnitude, phase)
1789+
info = dict(magnitude=magnitude, phase=phase)
1790+
return (wave, info) if with_info else wave

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.77",
6+
version="0.0.78",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments (0)