 from abc import ABC, abstractmethod
 import keras
 from keras import ops
-from keras.saving import register_keras_serializable as serializable

+from bayesflow.utils.serialization import serialize, deserialize, serializable
 from bayesflow.types import Tensor, Shape
 import bayesflow as bf
 from bayesflow.networks import InferenceNetwork
     expand_right_as,
     find_network,
     jacobian_trace,
-    keras_kwargs,
-    serialize_value_or_type,
-    deserialize_value_or_type,
+    layer_kwargs,
     weighted_mean,
     integrate,
 )
@@ -145,8 +143,8 @@ class LinearNoiseSchedule(NoiseSchedule):

     def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
         super().__init__(name="linear_noise_schedule")
-        self._log_snr_min = ops.convert_to_tensor(min_log_snr)
-        self._log_snr_max = ops.convert_to_tensor(max_log_snr)
+        self._log_snr_min = min_log_snr
+        self._log_snr_max = max_log_snr

         self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
         self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
@@ -192,11 +190,11 @@ class CosineNoiseSchedule(NoiseSchedule):
     [1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
     """

-    def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
+    def __init__(self, min_log_snr: float = -15.0, max_log_snr: float = 15.0, s_shift_cosine: float = 0.0):
         super().__init__(name="cosine_noise_schedule")
-        self._log_snr_min = ops.convert_to_tensor(min_log_snr)
-        self._log_snr_max = ops.convert_to_tensor(max_log_snr)
-        self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
+        self._log_snr_min = min_log_snr
+        self._log_snr_max = max_log_snr
+        self._s_shift_cosine = s_shift_cosine

         self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
         self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
@@ -210,7 +208,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
     def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
         # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
-        return 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
+        return 2.0 / math.pi * ops.arctan(ops.exp((2.0 * self._s_shift_cosine - log_snr_t) / 2.0))

     def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
@@ -241,12 +240,12 @@ class EDMNoiseSchedule(NoiseSchedule):

     def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
         super().__init__(name="edm_noise_schedule")
-        self.sigma_data = ops.convert_to_tensor(sigma_data)
-        self.sigma_max = ops.convert_to_tensor(sigma_max)
-        self.sigma_min = ops.convert_to_tensor(sigma_min)
-        self.p_mean = ops.convert_to_tensor(-1.2)
-        self.p_std = ops.convert_to_tensor(1.2)
-        self.rho = ops.convert_to_tensor(7)
+        self.sigma_data = sigma_data
+        self.sigma_max = sigma_max
+        self.sigma_min = sigma_min
+        self.p_mean = -1.2
+        self.p_std = 1.2
+        self.rho = 7

         # convert EDM parameters to signal-to-noise ratio formulation
         self._log_snr_min = -2 * ops.log(sigma_max)
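For orientation, a small worked example of that conversion with the default sigmas; the matching upper bound `_log_snr_max = -2 * ops.log(sigma_min)` is outside this hunk and is assumed here.

# Worked example of the EDM -> log-SNR conversion shown above (default sigmas).
# Assumes the symmetric upper bound -2 * log(sigma_min), which is not visible in this hunk.
import math

sigma_min, sigma_max = 0.002, 80.0
log_snr_min = -2.0 * math.log(sigma_max)  # about -8.76
log_snr_max = -2.0 * math.log(sigma_min)  # about 12.43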
@@ -336,7 +335,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         return ops.exp(-log_snr_t) + 0.5**2


-@serializable(package="bayesflow.networks")
+@serializable
 class DiffusionModel(InferenceNetwork):
     """Diffusion Model as described in this overview paper [1].

@@ -395,7 +394,7 @@ def __init__(
             Additional keyword arguments passed to the subnet and other components.
         """

-        super().__init__(base_distribution=None, **keras_kwargs(kwargs))
+        super().__init__(base_distribution=None, **kwargs)

         if isinstance(noise_schedule, str):
             if noise_schedule == "linear":
@@ -432,18 +431,11 @@ def __init__(
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

-        # serialization: store all parameters necessary to call __init__
-        self.config = {
-            "integrate_kwargs": self.integrate_kwargs,
-            "subnet_kwargs": subnet_kwargs,
-            "noise_schedule": self.noise_schedule,
-            "prediction_type": self.prediction_type,
-            **kwargs,
-        }
-        self.config = serialize_value_or_type(self.config, "subnet", subnet)
-
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
-        super().build(xz_shape, conditions_shape=conditions_shape)
+        if self.built:
+            return
+
+        self.base_distribution.build(xz_shape)

         self.output_projector.units = xz_shape[-1]
         input_shape = list(xz_shape)
@@ -461,12 +453,19 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:

     def get_config(self):
         base_config = super().get_config()
-        return base_config | self.config
+        base_config = layer_kwargs(base_config)
+
+        config = {
+            "subnet": self.subnet,
+            "noise_schedule": self.noise_schedule,
+            "integrate_kwargs": self.integrate_kwargs,
+            "prediction_type": self.prediction_type,
+        }
+        return base_config | serialize(config)

     @classmethod
-    def from_config(cls, config):
-        config = deserialize_value_or_type(config, "subnet")
-        return cls(**config)
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))

     def convert_prediction_to_x(
         self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
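The hunk above follows the standard Keras get_config/from_config contract: serialize nested objects into the config dict, then deserialize them in from_config before calling __init__. A minimal standalone sketch of the same round trip, using the plain Keras helpers instead of bayesflow.utils.serialization (an assumption made only for this illustration):

# Toy layer with a nested Keras object, mirroring the `subnet` handling above.
import keras

@keras.saving.register_keras_serializable(package="sketch")
class Wrapper(keras.layers.Layer):
    def __init__(self, inner, **kwargs):
        super().__init__(**kwargs)
        self.inner = inner  # nested Keras object, like `subnet`

    def call(self, x):
        return self.inner(x)

    def get_config(self):
        base_config = super().get_config()
        config = {"inner": keras.saving.serialize_keras_object(self.inner)}
        return base_config | config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config["inner"] = keras.saving.deserialize_keras_object(config["inner"], custom_objects=custom_objects)
        return cls(**config)

layer = Wrapper(keras.layers.Dense(4))
clone = Wrapper.from_config(layer.get_config())  # nested Dense is rebuilt from its config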
@@ -546,7 +545,14 @@ def _forward(
         training: bool = False,
         **kwargs,
     ) -> Tensor | tuple[Tensor, Tensor]:
-        integrate_kwargs = self.integrate_kwargs | kwargs
+        integrate_kwargs = (
+            {
+                "start_time": self.noise_schedule._t_min,
+                "stop_time": self.noise_schedule._t_max,
+            }
+            | self.integrate_kwargs
+            | kwargs
+        )
         if density:

             def deltas(time, xz):
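The three-way union above relies on Python's dict `|` operator, where the rightmost operand wins: schedule-derived defaults are overridden by the instance-level `integrate_kwargs`, which are in turn overridden by per-call keyword arguments. A tiny illustration (the "steps" key is made up for the example, not taken from the integrator's API):

# Rightmost dict wins in a | b | c, matching the precedence used above.
defaults = {"start_time": 0.0, "stop_time": 1.0, "steps": 100}  # schedule defaults
instance = {"steps": 250}                                       # self.integrate_kwargs
per_call = {"stop_time": 0.999}                                 # **kwargs at call time
merged = defaults | instance | per_call
assert merged == {"start_time": 0.0, "stop_time": 0.999, "steps": 250}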
@@ -588,7 +594,14 @@ def _inverse(
         training: bool = False,
         **kwargs,
     ) -> Tensor | tuple[Tensor, Tensor]:
-        integrate_kwargs = self.integrate_kwargs | kwargs
+        integrate_kwargs = (
+            {
+                "start_time": self.noise_schedule._t_max,
+                "stop_time": self.noise_schedule._t_min,
+            }
+            | self.integrate_kwargs
+            | kwargs
+        )
         if density:

             def deltas(time, xz):
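Note that `_inverse` only swaps the default start and stop times relative to `_forward`; with any fixed-step scheme this simply flips the sign of the step, as the generic Euler sketch below (not BayesFlow's `integrate` helper) illustrates.

# Generic illustration only: reversing start/stop flips the step sign,
# so one loop serves both the forward and the inverse pass.
def euler(f, x, start_time, stop_time, steps=100):
    dt = (stop_time - start_time) / steps  # negative when integrating backwards
    t = start_time
    for _ in range(steps):
        x = x + dt * f(t, x)
        t = t + dt
    return x

x1 = euler(lambda t, x: -x, 1.0, start_time=0.0, stop_time=1.0)  # forward in time
x0 = euler(lambda t, x: -x, x1, start_time=1.0, stop_time=0.0)   # reversed bounds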