from bayesflow.types import Tensor, Shape
import bayesflow as bf
from bayesflow.networks import InferenceNetwork
+import math

from bayesflow.utils import (
    expand_right_as,

@serializable(package="bayesflow.networks")
class DiffusionModel(InferenceNetwork):
24- """Diffusion Model as described as Elucidated Diffusion Model in [1].
25+ """Diffusion Model as described in this overview paper [1].
26+
27+ [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
28+ Augmentation: Kingma et al. (2023)
29+ [2] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021)
30+ [3] Elucidating the Design Space of Diffusion-Based Generative Models: arXiv:2206.00364
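+
+    A minimal construction sketch (assuming the surrounding BayesFlow workflow; the
+    keyword names below are the ones parsed from ``kwargs`` in ``__init__``):
+
+        diffusion = DiffusionModel(
+            noise_schedule="cosine",  # "EDM", "linear", "cosine", or "flow_matching"
+            loss_type="v",            # "eps", "v", or "EDM"
+            train_time="discrete",    # or "continuous"
+            timesteps=1000,           # required when train_time == "discrete"
+        )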

-    [1] Elucidating the Design Space of Diffusion-Based Generative Models: arXiv:2206.00364
    """

    MLP_DEFAULT_CONFIG = {
@@ -74,16 +79,34 @@ def __init__(

        super().__init__(base_distribution=None, **keras_kwargs(kwargs))

+        # todo: clean up these configurations
+        # EDM hyper-parameters
        # internal tunable parameters not intended to be modified by the average user
        self.max_sigma = kwargs.get("max_sigma", 80.0)
        self.min_sigma = kwargs.get("min_sigma", 1e-4)
        self.rho = kwargs.get("rho", 7)
        # hyper-parameters for sampling the noise level
        self.p_mean = kwargs.get("p_mean", -1.2)
        self.p_std = kwargs.get("p_std", 1.2)
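+        # supported schedules: "EDM", "linear", "cosine", "flow_matching" (see _get_log_snr)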
+        self._noise_schedule = kwargs.get("noise_schedule", "EDM")
+
+        # general hyper-parameters
+        self._train_time = kwargs.get("train_time", "continuous")
+        self._timesteps = kwargs.get("timesteps", None)
+        if self._train_time == "discrete":
+            if not isinstance(self._timesteps, int):
+                raise ValueError('timesteps must be an integer when train_time is "discrete"')
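+        # supported loss types: "eps" (noise prediction), "v" (velocity prediction), "EDM"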
+        self._loss_type = kwargs.get("loss_type", "eps")
+        self._weighting_function = kwargs.get("weighting_function", None)
+        self._log_snr_min = kwargs.get("log_snr_min", -15)
+        self._log_snr_max = kwargs.get("log_snr_max", 15)
+        if self._noise_schedule != "EDM":
+            # _get_t_from_log_snr is undefined for the EDM schedule, so only truncate the other schedules
+            self._t_min = self._get_t_from_log_snr(log_snr_t=self._log_snr_max)
+            self._t_max = self._get_t_from_log_snr(log_snr_t=self._log_snr_min)
+        self._s_shift_cosine = kwargs.get("s_shift_cosine", 0.0)

        # latent distribution (not configurable)
        self.base_distribution = bf.distributions.DiagonalNormal(mean=0.0, std=self.max_sigma)
+
        self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})

        self.sigma_data = sigma_data
@@ -142,51 +165,62 @@ def _c_in_fn(self, sigma):
        return 1.0 / ops.sqrt(sigma**2 + self.sigma_data**2)

    def _c_noise_fn(self, sigma):
-        return 0.25 * ops.log(sigma)
-
-    def _denoiser_fn(
-        self,
-        xz: Tensor,
-        sigma: Tensor,
-        conditions: Tensor = None,
-        training: bool = False,
-    ):
-        # calculate output of the network
-        c_in = self._c_in_fn(sigma)
-        c_noise = self._c_noise_fn(sigma)
-        xz_pre = c_in * xz
-        if conditions is None:
-            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
-        else:
-            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)
-        out = self.output_projector(self.subnet(xtc, training=training), training=training)
-        return self._c_skip_fn(sigma) * xz + self._c_out_fn(sigma) * out
+        return 0.25 * ops.log(sigma)  # proportional to the log-SNR, since log-SNR = -2 * log(sigma)

    def velocity(
        self,
        xz: Tensor,
-        sigma: float | Tensor,
+        time: float | Tensor,
        conditions: Tensor = None,
        training: bool = False,
+        clip_x: bool = True,
    ) -> Tensor:
-        # transform sigma vector into correct shape
-        sigma = keras.ops.convert_to_tensor(sigma, dtype=keras.ops.dtype(xz))
-        sigma = expand_right_as(sigma, xz)
-        sigma = keras.ops.broadcast_to(sigma, keras.ops.shape(xz)[:-1] + (1,))
+        # calculate the current noise level and transform it into the correct shape
+        time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
+        log_snr_t = expand_right_as(self._get_log_snr(t=time), xz)
+        # broadcast to match the batch shape of xz, as the removed sigma-based code did
+        log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
+        alpha_t, sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)

-        d = self._denoiser_fn(xz, sigma, conditions, training=training)
-        return (xz - d) / sigma
+        if self._noise_schedule == "EDM":
+            # scale the input
+            xz = alpha_t * xz
+
+        if conditions is None:
+            xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
+        else:
+            xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+
+        if self._noise_schedule == "EDM":
+            # scale the output
+            s = ops.exp(-0.5 * log_snr_t)
+            pred_scaled = self._c_skip_fn(s) * xz + self._c_out_fn(s) * pred
+            out = (xz - pred_scaled) / s
+        else:
+            # first convert the prediction to an x-prediction
+            if self._loss_type == "eps":
+                x_pred = (xz - sigma_t * pred) / alpha_t
+            else:  # self._loss_type == "v"
+                x_pred = alpha_t * xz - sigma_t * pred
+
+            # clip x if necessary
+            if clip_x:
+                x_pred = ops.clip(x_pred, -5, 5)
+            # convert the x-prediction to the score
+            score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
+            # compute the velocity for the ODE depending on the noise schedule
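+            # probability flow ODE (Song et al., 2021): dx/dt = f(x, t) - 0.5 * g(t)^2 * score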
+            f, g = self._get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
+            out = f - 0.5 * ops.square(g) * score
+        return out

    def _velocity_trace(
        self,
        xz: Tensor,
-        sigma: Tensor,
+        time: Tensor,
        conditions: Tensor = None,
        max_steps: int = None,
        training: bool = False,
    ) -> (Tensor, Tensor):
        def f(x):
-            return self.velocity(x, sigma=sigma, conditions=conditions, training=training)
+            return self.velocity(x, time=time, conditions=conditions, training=training)

        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

@@ -207,7 +241,7 @@ def _forward(
        if density:

            def deltas(time, xz):
-                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
@@ -226,7 +260,7 @@ def deltas(time, xz):
            return z, log_density

        def deltas(time, xz):
-            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

        state = {"xz": x}
        state = integrate(
@@ -256,7 +290,7 @@ def _inverse(
        if density:

            def deltas(time, xz):
-                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
@@ -271,7 +305,7 @@ def deltas(time, xz):
            return x, log_density

        def deltas(time, xz):
-            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

        state = {"xz": z}
        state = integrate(
@@ -284,6 +318,120 @@ def deltas(time, xz):

        return x

+    def _get_drift_diffusion(self, log_snr_t, x=None):  # t is not truncated
+        """
+        Compute the drift f and diffusion g of the SDE, where
+        g(t)^2 = beta(t) = d/dt log(1 + e^(-snr(t))) for the truncated schedules.
+        """
+        t = self._get_t_from_log_snr(log_snr_t=log_snr_t)
+        # compute the truncated time t_trunc
+        t_trunc = self._t_min + (self._t_max - self._t_min) * t

+        # compute d/dx snr(x) based on the noise schedule
+        if self._noise_schedule == "linear":
+            # d/dx snr(x) = -2*x*exp(x^2) / (exp(x^2) - 1)
+            dsnr_dx = -(2 * t_trunc * ops.exp(t_trunc**2)) / (ops.exp(t_trunc**2) - 1)
+        elif self._noise_schedule == "cosine":
+            # d/dx snr(x) = -2*pi / sin(pi*x)
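+            # follows from snr(x) = -2*log(tan(pi*x/2)) and d/dx log(tan(pi*x/2)) = pi / sin(pi*x)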
+            dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)
+        elif self._noise_schedule == "flow_matching":
+            # d/dx snr(x) = -2 / (x*(1-x))
+            dsnr_dx = -2 / (t_trunc * (1 - t_trunc))
+        else:
+            raise ValueError("Invalid 'noise_schedule'.")
+
+        # chain rule: d/dt snr(t) = d/dx snr(x) * (t_max - t_min)
+        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
+
+        # using the chain rule on f(t) = log(1 + e^(-snr(t))):
+        # f'(t) = -(e^(-snr(t)) / (1 + e^(-snr(t)))) * dsnr_dt
+        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
+
+        beta_t = -factor * dsnr_dt
+        g = ops.sqrt(beta_t)  # diffusion term
+        if x is None:
+            return g
+        f = -0.5 * beta_t * x  # drift term
+        return f, g
+
+    def _get_log_snr(self, t: Tensor) -> Tensor:
+        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
+        if self._noise_schedule == "EDM":
+            # EDM defines sigma_tilde ~ N(p_mean, p_std^2) with
+            # sigma_tilde^2 = exp(-lambda), hence lambda = -2 * log(sigma_tilde)
+            # sample the noise level
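+            # note that t is ignored here: the EDM schedule draws a fresh random noise level instead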
+            log_sigma_tilde = self.p_mean + self.p_std * keras.random.normal(
+                ops.shape(t), dtype=ops.dtype(t), seed=self.seed_generator
+            )
+            # calculate the log signal-to-noise ratio
+            log_snr_t = -2 * log_sigma_tilde
+            return log_snr_t
+
+        t_trunc = self._t_min + (self._t_max - self._t_min) * t
+        if self._noise_schedule == "linear":
+            log_snr_t = -ops.log(ops.exp(ops.square(t_trunc)) - 1)
+        elif self._noise_schedule == "cosine":  # usually used with variance-preserving diffusion
+            log_snr_t = -2 * ops.log(ops.tan(math.pi * t_trunc / 2)) + 2 * self._s_shift_cosine
+        elif self._noise_schedule == "flow_matching":  # usually used with sub-variance-preserving diffusion
+            log_snr_t = 2 * ops.log((1 - t_trunc) / t_trunc)
+        else:
+            raise ValueError(f"Unknown noise schedule: {self._noise_schedule}")
+        return log_snr_t
+
+    def _get_t_from_log_snr(self, log_snr_t) -> Tensor:
+        # invert the noise schedule to recover t (not truncated)
+        if self._noise_schedule == "linear":
+            # snr = -log(exp(t^2) - 1)
+            # => t = sqrt(log(1 + exp(-snr)))
+            t = ops.sqrt(ops.log(1 + ops.exp(-log_snr_t)))
+        elif self._noise_schedule == "cosine":
+            # snr = -2 * log(tan(pi*t/2)) + 2 * s_shift
+            # => t = 2/pi * arctan(exp((2 * s_shift - snr) / 2))
+            t = 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
+        elif self._noise_schedule == "flow_matching":
+            # snr = 2 * log((1-t)/t)
+            # => t = 1 / (1 + exp(snr/2))
+            t = 1 / (1 + ops.exp(log_snr_t / 2))
+        elif self._noise_schedule == "EDM":
+            raise NotImplementedError("The EDM schedule defines no deterministic mapping from log-SNR to t.")
+        else:
+            raise ValueError(f"Unknown noise schedule: {self._noise_schedule}")
+        return t
+
+    def _get_alpha_sigma(self, log_snr_t: Tensor) -> tuple[Tensor, Tensor]:
+        if self._noise_schedule == "EDM":
+            # EDM: noisy_x = c_in * (x + s * e) = c_in * x + c_in * s * e
+            # with s^2 = exp(-lambda)
+            s = ops.exp(-0.5 * log_snr_t)
+            c_in = self._c_in_fn(s)
+
+            # alpha = c_in, sigma = c_in * s
+            alpha_t = c_in
+            sigma_t = c_in * s
+        else:
+            # variance-preserving noise schedules
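+            # alpha_t^2 + sigma_t^2 = 1, since sigmoid(x) + sigmoid(-x) = 1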
+            alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
+            sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
+        return alpha_t, sigma_t
+
+    def _get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
+        if self._noise_schedule == "EDM":
+            # EDM: the weights are constructed in the loss itself (see compute_metrics)
+            weights = ops.ones_like(log_snr_t)
+            return weights
+
+        if self._weighting_function == "likelihood_weighting":  # based on Song et al. (2021)
+            g_t = self._get_drift_diffusion(log_snr_t=log_snr_t)
+            sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)[1]
+            weights = ops.square(g_t / sigma_t)
+        elif self._weighting_function == "sigmoid":  # based on Kingma et al. (2023)
+            weights = ops.sigmoid(-log_snr_t / 2)
+        elif self._weighting_function == "min-snr":  # based on Hang et al. (2023)
+            gamma = 5
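+            # gamma = 5 is the default suggested by Hang et al. (2023)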
+            weights = 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
+        else:
+            weights = ops.ones_like(log_snr_t)
+        return weights
+
    def compute_metrics(
        self,
        x: Tensor | Sequence[Tensor, ...],
@@ -297,36 +445,51 @@ def compute_metrics(
        conditions_shape = None if conditions is None else keras.ops.shape(conditions)
        self.build(xz_shape, conditions_shape)

-        # sample log-noise level
-        log_sigma = self.p_mean + self.p_std * keras.random.normal(
-            ops.shape(x)[:1], dtype=ops.dtype(x), seed=self.seed_generator
-        )
-        # noise level with shape (batch_size, 1)
-        sigma = ops.exp(log_sigma)[:, None]
+        # sample the training diffusion time
+        if self._train_time == "continuous":
+            t = keras.random.uniform((keras.ops.shape(x)[0],), dtype=keras.ops.dtype(x), seed=self.seed_generator)
+        elif self._train_time == "discrete":
+            i = keras.random.randint(
+                (keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps, seed=self.seed_generator
+            )
+            t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
+        else:
+            raise NotImplementedError(f"Training time {self._train_time} not implemented")
+
+        # calculate the noise level
+        log_snr_t = expand_right_as(self._get_log_snr(t), x)
+        alpha_t, sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)

        # generate noise vector
-        z = sigma * keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)
+        eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)

-        # calculate preconditioning
-        c_skip = self._c_skip_fn(sigma)
-        c_out = self._c_out_fn(sigma)
-        c_in = self._c_in_fn(sigma)
-        c_noise = self._c_noise_fn(sigma)
-        xz_pre = c_in * (x + z)
+        # diffuse x
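+        # forward process: z_t = alpha_t * x + sigma_t * eps_t (Kingma et al., 2023)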
+        diffused_x = alpha_t * x + sigma_t * eps_t

        # calculate output of the network
        if conditions is None:
-            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
+            xtc = keras.ops.concatenate([diffused_x, log_snr_t], axis=-1)
        else:
-            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)
+            xtc = keras.ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)

        out = self.output_projector(self.subnet(xtc, training=training), training=training)

-        # Calculate loss:
-        lam = 1 / c_out[:, 0] ** 2
-        effective_weight = lam * c_out[:, 0] ** 2
-        unweighted_loss = ops.mean((out - 1 / c_out * (x - c_skip * (x + z))) ** 2, axis=-1)
-        loss = effective_weight * unweighted_loss
+        # Calculate loss
+        weights_for_snr = self._get_weights_for_snr(log_snr_t=log_snr_t)
+        if self._loss_type == "eps":
+            loss = weights_for_snr * ops.mean((out - eps_t) ** 2, axis=-1)
+        elif self._loss_type == "v":
+            v_t = alpha_t * eps_t - sigma_t * x
+            loss = weights_for_snr * ops.mean((out - v_t) ** 2, axis=-1)
+        elif self._loss_type == "EDM":
+            s = ops.exp(-0.5 * log_snr_t)
+            c_skip = self._c_skip_fn(s)
+            c_out = self._c_out_fn(s)
+            lam = 1 / c_out[:, 0] ** 2
+            effective_weight = lam * c_out[:, 0] ** 2
+            unweighted_loss = ops.mean((out - 1 / c_out * (x - c_skip * (x + s * eps_t))) ** 2, axis=-1)
+            loss = effective_weight * unweighted_loss
+        else:
+            raise ValueError(f"Unknown loss type: {self._loss_type}")
+
        loss = weighted_mean(loss, sample_weight)

        base_metrics = super().compute_metrics(x, conditions, sample_weight, stage)