@@ -147,66 +147,12 @@ def validate(self):
             raise ValueError("dt/t log_snr(1) must be finite.")
 
 
-@serializable
-class LinearNoiseSchedule(NoiseSchedule):
-    """Linear noise schedule for diffusion models.
-
-    The linear noise schedule with likelihood weighting is based on [1].
-
-    [1] Maximum Likelihood Training of Score-Based Diffusion Models: Song et al. (2021)
-    """
-
-    def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
-        super().__init__(name="linear_noise_schedule", variance_type="preserving", weighting="likelihood_weighting")
-        self.log_snr_min = min_log_snr
-        self.log_snr_max = max_log_snr
-
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
-
-    def _truncated_t(self, t: Tensor) -> Tensor:
-        return self._t_min + (self._t_max - self._t_min) * t
-
-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
-        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
-        t_trunc = self._truncated_t(t)
-        # SNR = -log(exp(t^2) - 1)
-        # equivalent, but more stable: -t^2 - log(1 - exp(-t^2))
-        return -ops.square(t_trunc) - ops.log(1 - ops.exp(-ops.square(t_trunc)))
-
-    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
-        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
-        # SNR = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-snr)))
-        return ops.sqrt(ops.softplus(-log_snr_t))
-
-    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
-        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
-        t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
-
-        # Compute the truncated time t_trunc
-        t_trunc = self._truncated_t(t)
-        dsnr_dx = -2 * t_trunc / (1 - ops.exp(-(t_trunc**2)))
-
-        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
-        # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
-        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
-        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
-        return -factor * dsnr_dt
-
-    def get_config(self):
-        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
-
-
 @serializable
 class CosineNoiseSchedule(NoiseSchedule):
     """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
     For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
 
-    [1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
+    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2021)
     """
 
     def __init__(
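The shift mentioned in the docstring is just a constant offset on the log-SNR. A minimal sketch of how one might compute it (the concrete numbers are hypothetical, not taken from this diff):

```python
import math

# Hypothetical values: a schedule tuned at base_resolution = 32,
# applied to 64x64 images (d = 64).
base_resolution, d = 32, 64
s_shift_cosine = math.log(base_resolution / d)  # log(32 / 64) = -0.69...
```
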
@@ -371,6 +317,7 @@ class DiffusionModel(InferenceNetwork):
 
     def __init__(
         self,
+        *,
         subnet: str | type = "mlp",
         integrate_kwargs: dict[str, any] = None,
         subnet_kwargs: dict[str, any] = None,
@@ -384,8 +331,8 @@ def __init__(
         This model learns a transformation from a Gaussian latent distribution to a target distribution using a
         specified subnet type, which can be an MLP or a custom network.
 
-        The integration steps can be customized with additional parameters available in the respective
-        configuration dictionary.
+        The integration can be customized with additional parameters available in the integrate_kwargs
+        configuration dictionary. Different noise schedules and prediction types are available.
 
         Parameters
         ----------
@@ -397,7 +344,7 @@ def __init__(
         subnet_kwargs : dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
         noise_schedule : str or NoiseSchedule, optional
-            The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
+            The noise schedule used for the diffusion process. Can be "cosine" or "edm".
             Default is "edm".
         prediction_type : str, optional
             The type of prediction used in the diffusion model. Can be "velocity", "noise", or "F" (EDM).
@@ -408,9 +355,7 @@ def __init__(
         super().__init__(base_distribution="normal", **kwargs)
 
         if isinstance(noise_schedule, str):
-            if noise_schedule == "linear":
-                noise_schedule = LinearNoiseSchedule()
-            elif noise_schedule == "cosine":
+            if noise_schedule == "cosine":
                 noise_schedule = CosineNoiseSchedule()
             elif noise_schedule == "edm":
                 noise_schedule = EDMNoiseSchedule()
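A minimal construction sketch under the new API, assuming `DiffusionModel` has been imported (the import path is not shown in this diff). Arguments are now keyword-only, and "linear" is no longer a valid schedule string:

```python
# Keyword-only construction; passing subnet positionally would now raise a TypeError.
model = DiffusionModel(subnet="mlp", noise_schedule="edm", prediction_type="F")
```
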
@@ -435,10 +380,12 @@ def __init__(
         )
 
         # clipping of prediction (after it was transformed to x-prediction)
-        self._clip_min = -5.0
-        self._clip_max = 5.0
+        # keeping this private for now, as it is usually not required in SBI and somewhat dangerous
+        self._clip_x = kwargs.get("clip_x", None)
+        if self._clip_x is not None:
+            if len(self._clip_x) != 2 or self._clip_x[0] > self._clip_x[1]:
+                raise ValueError("'clip_x' must be a list or tuple of the form [x_min, x_max].")
 
-        # latent distribution (not configurable)
         self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
         self.seed_generator = keras.random.SeedGenerator()
 
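A usage sketch for the new option (the values are hypothetical; they mirror the old hard-coded range):

```python
# Clip x-predictions to [-5, 5], the range that was previously hard-coded.
model = DiffusionModel(subnet="mlp", clip_x=(-5.0, 5.0))

# An inverted range fails the validation above with a ValueError:
# DiffusionModel(subnet="mlp", clip_x=(5.0, -5.0))
```
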
@@ -456,6 +403,8 @@ def __init__(
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
 
+        self._kwargs = kwargs
+
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
         if self.built:
             return
@@ -480,7 +429,7 @@ def get_config(self):
         base_config = super().get_config()
         base_config = layer_kwargs(base_config)
 
-        config = {
+        config = self._kwargs | {
            "subnet": self.subnet,
            "noise_schedule": self.noise_schedule,
            "integrate_kwargs": self.integrate_kwargs,
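Merging `self._kwargs` into the config means extra options such as `clip_x` now survive serialization. A minimal round-trip sketch, reusing the `model` from the sketches above (the round-trip idiom is assumed, not shown in this diff):

```python
# clip_x is carried through get_config()/from_config() via self._kwargs.
config = model.get_config()
restored = DiffusionModel.from_config(config)
```
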
@@ -494,7 +443,7 @@ def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))
 
     def convert_prediction_to_x(
-        self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
+        self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor
     ) -> Tensor:
         """Convert the prediction of the neural network to the x space."""
         if self._prediction_type == "velocity":
@@ -504,7 +453,7 @@ def convert_prediction_to_x(
             # convert noise prediction into x
             x = (z - sigma_t * pred) / alpha_t
         elif self._prediction_type == "F":  # EDM
-            sigma_data = self.noise_schedule.sigma_data
+            sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
             x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
             x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
             x = x1 * z + x2 * pred
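For the "noise" branch above, the conversion is just the inverse of the forward perturbation z = alpha_t * x + sigma_t * eps. A quick numeric sanity check (NumPy as a stand-in for keras.ops; values arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
x, eps = rng.normal(size=3), rng.normal(size=3)
alpha_t, sigma_t = 0.8, 0.6

z = alpha_t * x + sigma_t * eps          # forward diffusion
x_rec = (z - sigma_t * eps) / alpha_t    # "noise" branch with a perfect prediction
assert np.allclose(x_rec, x)
```
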
@@ -513,8 +462,8 @@ def convert_prediction_to_x(
         else:  # "score"
             x = (z + sigma_t**2 * pred) / alpha_t
 
-        if clip_x:
-            x = ops.clip(x, self._clip_min, self._clip_max)
+        if self._clip_x is not None:
+            x = ops.clip(x, self._clip_x[0], self._clip_x[1])
         return x
 
     def velocity(
@@ -524,7 +473,6 @@ def velocity(
         stochastic_solver: bool,
         conditions: Tensor = None,
         training: bool = False,
-        clip_x: bool = False,
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
@@ -537,9 +485,7 @@ def velocity(
         xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
-        x_pred = self.convert_prediction_to_x(
-            pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=clip_x
-        )
+        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
         # convert x to score
         score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
 
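The score line follows from the Gaussian perturbation kernel q(z_t | x) = N(alpha_t * x, sigma_t^2 * I), whose log-density has gradient (alpha_t * x - z_t) / sigma_t^2 in z_t. A numeric check of that identity (NumPy stand-in, arbitrary values):

```python
import numpy as np

alpha_t, sigma_t, x = 0.9, 0.5, 1.3
z = np.linspace(-2.0, 2.0, 101)

log_q = -0.5 * ((z - alpha_t * x) / sigma_t) ** 2   # log N(alpha_t * x, sigma_t^2), up to a constant
numeric = np.gradient(log_q, z)
analytic = (alpha_t * x - z) / sigma_t**2
assert np.allclose(numeric[1:-1], analytic[1:-1])   # central differences are exact for a quadratic
```
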
@@ -725,6 +671,7 @@ def compute_metrics(
         stage: str = "training",
     ) -> dict[str, Tensor]:
         training = stage == "training"
+        # use the same noise schedule for training and validation to keep them comparable
         noise_schedule_training_stage = stage == "training" or stage == "validation"
         if not self.built:
             xz_shape = ops.shape(x)
@@ -760,7 +707,7 @@ def compute_metrics(
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
-            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=False
+            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
         )
 
         # Calculate loss
@@ -775,7 +722,7 @@ def compute_metrics(
             loss = weights_for_snr * ops.mean((velocity_pred - v_t) ** 2, axis=-1)
         elif self._loss_type == "F":
             # convert x to F prediction
-            sigma_data = self.noise_schedule.sigma_data
+            sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
             x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
             x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
             f_pred = x1 * x_pred - x2 * diffused_x
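This x-to-F conversion is the algebraic inverse of the "F" branch in convert_prediction_to_x, so a perfect prediction survives the round trip. A quick check with the same formulas (NumPy stand-in, arbitrary values):

```python
import numpy as np

sigma_data, alpha_t, log_snr_t = 1.0, 0.8, 0.5
z, pred = 0.7, -0.3
e = np.exp(-log_snr_t)

# convert_prediction_to_x, "F" branch: F-prediction -> x
x = (sigma_data**2 * alpha_t) / (e + sigma_data**2) * z \
    + np.exp(-log_snr_t / 2) * sigma_data / np.sqrt(e + sigma_data**2) * pred

# loss branch above: x -> F-prediction recovers the original prediction
x1 = np.sqrt(e + sigma_data**2) / (np.exp(-log_snr_t / 2) * sigma_data)
x2 = (sigma_data * alpha_t) / (np.exp(-log_snr_t / 2) * np.sqrt(e + sigma_data**2))
assert np.isclose(x1 * x - x2 * z, pred)
```
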