Skip to content

Commit b1b859e

Browse files
feat: add clamp option
1 parent 24ff00f commit b1b859e

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

audio_diffusion_pytorch/diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,14 @@ def __init__(
555555
sampler: Sampler,
556556
sigma_schedule: Schedule,
557557
num_steps: Optional[int] = None,
558+
clamp: bool = True,
558559
):
559560
super().__init__()
560561
self.denoise_fn = diffusion.denoise_fn
561562
self.sampler = sampler
562563
self.sigma_schedule = sigma_schedule
563564
self.num_steps = num_steps
565+
self.clamp = clamp
564566

565567
# Check sampler is compatible with diffusion type
566568
sampler_class = sampler.__class__.__name__
@@ -581,7 +583,7 @@ def forward(
581583
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
582584
# Sample using sampler
583585
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
584-
x = x.clamp(-1.0, 1.0)
586+
x = x.clamp(-1.0, 1.0) if self.clamp else x
585587
return x
586588

587589

audio_diffusion_pytorch/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ def sample(
5757
num_steps: int,
5858
sigma_schedule: Schedule,
5959
sampler: Sampler,
60+
clamp: bool,
6061
**kwargs,
6162
) -> Tensor:
6263
diffusion_sampler = DiffusionSampler(
6364
diffusion=self.diffusion,
6465
sampler=sampler,
6566
sigma_schedule=sigma_schedule,
6667
num_steps=num_steps,
68+
clamp=clamp,
6769
)
6870
return diffusion_sampler(noise, **kwargs)
6971

@@ -251,10 +253,7 @@ def get_default_model_kwargs():
251253

252254

253255
def get_default_sampling_kwargs():
254-
return dict(
255-
sigma_schedule=LinearSchedule(),
256-
sampler=VSampler(),
257-
)
256+
return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
258257

259258

260259
class AudioDiffusionModel(Model1d):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.74",
6+
version="0.0.75",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)