Skip to content

Commit 4c48dce

Browse files
feat: forward args to denoise function
1 parent 365d0cd commit 4c48dce

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

audio_diffusion_pytorch/diffusion.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,7 @@ def denoise_fn(
244244
x_noisy: Tensor,
245245
sigmas: Optional[Tensor] = None,
246246
sigma: Optional[float] = None,
247+
**kwargs,
247248
) -> Tensor:
248249
batch, device = x_noisy.shape[0], x_noisy.device
249250

@@ -257,7 +258,7 @@ def denoise_fn(
257258

258259
# Predict network output and add skip connection
259260
c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
260-
x_pred = self.net(c_in * x_noisy, c_noise)
261+
x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
261262
x_denoised = c_skip * x_noisy + c_out * x_pred
262263

263264
# Dynamic thresholding
@@ -278,7 +279,7 @@ def loss_weight(self, sigmas: Tensor) -> Tensor:
278279
# Computes weight depending on data distribution
279280
return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
280281

281-
def forward(self, x: Tensor, noise: Tensor = None) -> Tensor:
282+
def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
282283
batch, device = x.shape[0], x.device
283284

284285
# Sample amount of noise to add for each batch element
@@ -290,7 +291,7 @@ def forward(self, x: Tensor, noise: Tensor = None) -> Tensor:
290291
x_noisy = x + sigmas_padded * noise
291292

292293
# Compute denoised values
293-
x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas)
294+
x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
294295

295296
# Compute weighted loss
296297
losses = F.mse_loss(x_denoised, x, reduction="none")
@@ -317,14 +318,18 @@ def __init__(
317318
self.num_steps = num_steps
318319

319320
@torch.no_grad()
320-
def forward(self, noise: Tensor, num_steps: Optional[int] = None) -> Tensor:
321+
def forward(
322+
self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
323+
) -> Tensor:
321324
device = noise.device
322325
num_steps = default(num_steps, self.num_steps) # type: ignore
323326
assert exists(num_steps), "Parameter `num_steps` must be provided"
324327
# Compute sigmas using schedule
325328
sigmas = self.sigma_schedule(num_steps, device)
329+
# Append additional kwargs to denoise function (used e.g. for conditional unet)
330+
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
326331
# Sample using sampler
327-
x = self.sampler(noise, fn=self.denoise_fn, sigmas=sigmas, num_steps=num_steps)
332+
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
328333
x = x.clamp(-1.0, 1.0)
329334
return x
330335

0 commit comments

Comments (0)