Skip to content

Commit 5144b72

Browse files
feat: fix diffMAE encoding being wrong when stft_use_complex is enabled
1 parent fd0b101 commit 5144b72

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,9 @@ def __init__(
199199
encoder_multipliers: Sequence[int],
200200
diffusion_type: str,
201201
stft_num_fft: int,
202+
stft_hop_length: int,
203+
stft_use_complex: bool,
204+
stft_window_length: Optional[int] = None,
202205
encoder_patch_size: int = 1,
203206
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
204207
bottleneck_channels: Optional[int] = None,
@@ -209,6 +212,7 @@ def __init__(
209212

210213
encoder_kwargs, kwargs = groupby("encoder_", kwargs)
211214
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
215+
stft_kwargs, kwargs = groupby("stft_", kwargs)
212216

213217
# Compute context channels
214218
context_channels = [0] * encoder_inject_depth
@@ -218,17 +222,26 @@ def __init__(
218222
context_channels += [encoder_channels * encoder_multipliers[-1]]
219223

220224
self.spectrogram_channels = stft_num_fft // 2 + 1
225+
self.stft_hop_length = stft_hop_length
226+
227+
self.encoder_stft = STFT(
228+
num_fft=stft_num_fft,
229+
hop_length=stft_hop_length,
230+
window_length=stft_window_length,
231+
use_complex=False, # Magnitude encoding
232+
)
221233

222234
self.unet = UNet1d(
223235
in_channels=in_channels,
224-
stft_num_fft=stft_num_fft,
225236
context_channels=context_channels,
226237
use_stft=True,
238+
stft_use_complex=stft_use_complex,
239+
stft_num_fft=stft_num_fft,
240+
stft_hop_length=stft_hop_length,
241+
stft_window_length=stft_window_length,
227242
**kwargs,
228243
)
229244

230-
self.stft = self.unet.stft
231-
232245
self.diffusion = XDiffusion(
233246
type=diffusion_type, net=self.unet, **diffusion_kwargs
234247
)
@@ -251,7 +264,7 @@ def encode(
251264
self, x: Tensor, with_info: bool = False
252265
) -> Union[Tensor, Tuple[Tensor, Any]]:
253266
# Extract magnitude and encode
254-
magnitude, _ = self.stft.encode(x)
267+
magnitude, _ = self.encoder_stft.encode(x)
255268
magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
256269
latent, info = self.encoder(magnitude_flat, with_info=True)
257270
# Apply bottlenecks if present
@@ -270,7 +283,7 @@ def forward( # type: ignore
270283
def decode(self, latent: Tensor, **kwargs) -> Tensor:
271284
b = latent.shape[0]
272285
length = closest_power_2(
273-
self.stft.hop_length * latent.shape[2] * self.encoder_downsample_factor
286+
self.stft_hop_length * latent.shape[2] * self.encoder_downsample_factor
274287
)
275288
# Compute noise by inferring shape from latent length
276289
noise = torch.randn(b, self.in_channels, length, device=latent.device)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.89",
6+
version="0.0.90",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)