@@ -244,6 +244,7 @@ def denoise_fn(
244244 x_noisy : Tensor ,
245245 sigmas : Optional [Tensor ] = None ,
246246 sigma : Optional [float ] = None ,
247+ ** kwargs ,
247248 ) -> Tensor :
248249 batch , device = x_noisy .shape [0 ], x_noisy .device
249250
@@ -257,7 +258,7 @@ def denoise_fn(
257258
258259 # Predict network output and add skip connection
259260 c_skip , c_out , c_in , c_noise = self .get_scale_weights (sigmas )
260- x_pred = self .net (c_in * x_noisy , c_noise )
261+ x_pred = self .net (c_in * x_noisy , c_noise , ** kwargs )
261262 x_denoised = c_skip * x_noisy + c_out * x_pred
262263
263264 # Dynamic thresholding
@@ -278,7 +279,7 @@ def loss_weight(self, sigmas: Tensor) -> Tensor:
278279 # Computes weight depending on data distribution
279280 return (sigmas ** 2 + self .sigma_data ** 2 ) * (sigmas * self .sigma_data ) ** - 2
280281
281- def forward (self , x : Tensor , noise : Tensor = None ) -> Tensor :
282+ def forward (self , x : Tensor , noise : Tensor = None , ** kwargs ) -> Tensor :
282283 batch , device = x .shape [0 ], x .device
283284
284285 # Sample amount of noise to add for each batch element
@@ -290,7 +291,7 @@ def forward(self, x: Tensor, noise: Tensor = None) -> Tensor:
290291 x_noisy = x + sigmas_padded * noise
291292
292293 # Compute denoised values
293- x_denoised = self .denoise_fn (x_noisy , sigmas = sigmas )
294+ x_denoised = self .denoise_fn (x_noisy , sigmas = sigmas , ** kwargs )
294295
295296 # Compute weighted loss
296297 losses = F .mse_loss (x_denoised , x , reduction = "none" )
@@ -317,14 +318,18 @@ def __init__(
317318 self .num_steps = num_steps
318319
319320 @torch .no_grad ()
320- def forward (self , noise : Tensor , num_steps : Optional [int ] = None ) -> Tensor :
321+ def forward (
322+ self , noise : Tensor , num_steps : Optional [int ] = None , ** kwargs
323+ ) -> Tensor :
321324 device = noise .device
322325 num_steps = default (num_steps , self .num_steps ) # type: ignore
323326 assert exists (num_steps ), "Parameter `num_steps` must be provided"
324327 # Compute sigmas using schedule
325328 sigmas = self .sigma_schedule (num_steps , device )
329+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
330+ fn = lambda * a , ** ka : self .denoise_fn (* a , ** {** ka , ** kwargs }) # noqa
326331 # Sample using sampler
327- x = self .sampler (noise , fn = self . denoise_fn , sigmas = sigmas , num_steps = num_steps )
332+ x = self .sampler (noise , fn = fn , sigmas = sigmas , num_steps = num_steps )
328333 x = x .clamp (- 1.0 , 1.0 )
329334 return x
330335
0 commit comments