Skip to content

Commit 425e559

Browse files
feat: context channels in model1d
1 parent 19ec856 commit 425e559

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ def __init__(
4040
diffusion_sigma_data: int,
4141
diffusion_dynamic_threshold: float,
4242
out_channels: Optional[int] = None,
43+
context_channels: Optional[Sequence[int]] = None,
4344
use_autoencoder: bool = False,
4445
autoencoder: Optional[AutoEncoder1d] = None,
4546
autoencoder_scale: float = 1.0,
@@ -72,6 +73,7 @@ def __init__(
7273
use_skip_scale=use_skip_scale,
7374
use_attention_bottleneck=use_attention_bottleneck,
7475
out_channels=out_channels,
76+
context_channels=context_channels,
7577
)
7678

7779
self.diffusion = Diffusion(
@@ -81,21 +83,26 @@ def __init__(
8183
dynamic_threshold=diffusion_dynamic_threshold,
8284
)
8385

84-
def forward(self, x: Tensor) -> Tensor:
86+
def forward(self, x: Tensor, **kwargs) -> Tensor:
8587
if self.use_autoencoder:
8688
x = self.autoencoder_scale * self.autoencoder.encode(x) # type: ignore
87-
return self.diffusion(x)
89+
return self.diffusion(x, **kwargs)
8890

8991
def sample(
90-
self, noise: Tensor, num_steps: int, sigma_schedule: Schedule, sampler: Sampler
92+
self,
93+
noise: Tensor,
94+
num_steps: int,
95+
sigma_schedule: Schedule,
96+
sampler: Sampler,
97+
**kwargs
9198
) -> Tensor:
9299
diffusion_sampler = DiffusionSampler(
93100
diffusion=self.diffusion,
94101
sampler=sampler,
95102
sigma_schedule=sigma_schedule,
96103
num_steps=num_steps,
97104
)
98-
x = diffusion_sampler(noise)
105+
x = diffusion_sampler(noise, **kwargs)
99106

100107
if self.use_autoencoder:
101108
x = (1.0 / self.autoencoder_scale) * self.autoencoder.decode(x)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.16",
6+
version="0.0.17",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)