22from abc import ABC , abstractmethod
33import keras
44from keras import ops
5- from keras .saving import register_keras_serializable as serializable
65
6+ from bayesflow .utils .serialization import serialize , deserialize , serializable
77from bayesflow .types import Tensor , Shape
88import bayesflow as bf
99from bayesflow .networks import InferenceNetwork
1313 expand_right_as ,
1414 find_network ,
1515 jacobian_trace ,
16- serialize_value_or_type ,
17- deserialize_value_or_type ,
16+ layer_kwargs ,
1817 weighted_mean ,
1918 integrate ,
2019)
@@ -132,9 +131,9 @@ class LinearNoiseSchedule(NoiseSchedule):
132131 """
133132
134133 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 ):
135- super ().__init__ (name = "linear_noise_schedule" , variance_type = "preserving" )
136- self ._log_snr_min = ops . convert_to_tensor ( min_log_snr )
137- self ._log_snr_max = ops . convert_to_tensor ( max_log_snr )
134+ super ().__init__ (name = "linear_noise_schedule" )
135+ self ._log_snr_min = min_log_snr
136+ self ._log_snr_max = max_log_snr
138137
139138 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
140139 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )  # was _log_snr_max, which made _t_max identical to _t_min
@@ -182,9 +181,10 @@ class CosineNoiseSchedule(NoiseSchedule):
182181
183182 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 , s_shift_cosine : float = 0.0 ):
184183 super ().__init__ (name = "cosine_noise_schedule" , variance_type = "preserving" )
185- self ._log_snr_min = ops .convert_to_tensor (min_log_snr )
186- self ._log_snr_max = ops .convert_to_tensor (max_log_snr )
187184 self ._s_shift_cosine = ops .convert_to_tensor (s_shift_cosine )
185+ self ._log_snr_min = min_log_snr
186+ self ._log_snr_max = max_log_snr
187+ self ._s_shift_cosine = s_shift_cosine
188188
189189 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
190190 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )  # was _log_snr_max, which made _t_max identical to _t_min
@@ -229,12 +229,13 @@ class EDMNoiseSchedule(NoiseSchedule):
229229
230230 def __init__ (self , sigma_data : float = 0.5 , sigma_min : float = 0.002 , sigma_max : float = 80 ):
231231 super ().__init__ (name = "edm_noise_schedule" , variance_type = "exploding" )
232- self .sigma_data = ops .convert_to_tensor (sigma_data )
233- self .sigma_max = ops .convert_to_tensor (sigma_max )
234- self .sigma_min = ops .convert_to_tensor (sigma_min )
235- self .p_mean = ops .convert_to_tensor (- 1.2 )
236- self .p_std = ops .convert_to_tensor (1.2 )
237- self .rho = ops .convert_to_tensor (7 )
232+ super ().__init__ (name = "edm_noise_schedule" )
233+ self .sigma_data = sigma_data
234+ self .sigma_max = sigma_max
235+ self .sigma_min = sigma_min
236+ self .p_mean = - 1.2
237+ self .p_std = 1.2
238+ self .rho = 7
238239
239240 # convert EDM parameters to signal-to-noise ratio formulation
240241 self ._log_snr_min = - 2 * ops .log (sigma_max )
@@ -306,7 +307,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
306307 return ops .exp (- log_snr_t ) + 0.5 ** 2
307308
308309
309- @serializable ( package = "bayesflow.networks" )
310+ @serializable
310311class DiffusionModel (InferenceNetwork ):
311312 """Diffusion Model as described in this overview paper [1].
312313
@@ -401,18 +402,11 @@ def __init__(
401402 self .subnet = find_network (subnet , ** subnet_kwargs )
402403 self .output_projector = keras .layers .Dense (units = None , bias_initializer = "zeros" )
403404
404- # serialization: store all parameters necessary to call __init__
405- self .config = {
406- "integrate_kwargs" : self .integrate_kwargs ,
407- "subnet_kwargs" : subnet_kwargs ,
408- "noise_schedule" : self .noise_schedule ,
409- "prediction_type" : self .prediction_type ,
410- ** kwargs ,
411- }
412- self .config = serialize_value_or_type (self .config , "subnet" , subnet )
413-
414405 def build (self , xz_shape : Shape , conditions_shape : Shape = None ) -> None :
415- super ().build (xz_shape , conditions_shape = conditions_shape )
406+ if self .built :
407+ return
408+
409+ self .base_distribution .build (xz_shape )
416410
417411 self .output_projector .units = xz_shape [- 1 ]
418412 input_shape = list (xz_shape )
@@ -430,12 +424,19 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
430424
431425 def get_config (self ):
432426 base_config = super ().get_config ()
433- return base_config | self .config
427+ base_config = layer_kwargs (base_config )
428+
429+ config = {
430+ "subnet" : self .subnet ,
431+ "noise_schedule" : self .noise_schedule ,
432+ "integrate_kwargs" : self .integrate_kwargs ,
433+ "prediction_type" : self .prediction_type ,
434+ }
435+ return base_config | serialize (config )
434436
435437 @classmethod
436- def from_config (cls , config ):
437- config = deserialize_value_or_type (config , "subnet" )
438- return cls (** config )
438+ def from_config (cls , config , custom_objects = None ):
439+ return cls (** deserialize (config , custom_objects = custom_objects ))
439440
440441 def convert_prediction_to_x (
441442 self , pred : Tensor , z : Tensor , alpha_t : Tensor , sigma_t : Tensor , log_snr_t : Tensor , clip_x : bool
@@ -515,7 +516,14 @@ def _forward(
515516 training : bool = False ,
516517 ** kwargs ,
517518 ) -> Tensor | tuple [Tensor , Tensor ]:
518- integrate_kwargs = self .integrate_kwargs | kwargs
519+ integrate_kwargs = (
520+ {
521+ "start_time" : self .noise_schedule ._t_min ,
522+ "stop_time" : self .noise_schedule ._t_max ,
523+ }
524+ | self .integrate_kwargs
525+ | kwargs
526+ )
519527 if density :
520528
521529 def deltas (time , xz ):
@@ -557,7 +565,14 @@ def _inverse(
557565 training : bool = False ,
558566 ** kwargs ,
559567 ) -> Tensor | tuple [Tensor , Tensor ]:
560- integrate_kwargs = self .integrate_kwargs | kwargs
568+ integrate_kwargs = (
569+ {
570+ "start_time" : self .noise_schedule ._t_max ,
571+ "stop_time" : self .noise_schedule ._t_min ,
572+ }
573+ | self .integrate_kwargs
574+ | kwargs
575+ )
561576 if density :
562577
563578 def deltas (time , xz ):
0 commit comments