+from math import pi
+
 import keras
 from keras import ops
 
-import numpy as np
-
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import (
+    logging,
     jvp,
     concatenate_valid,
     find_network,
     expand_right_as,
     expand_right_to,
-    model_kwargs,
+    layer_kwargs,
 )
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
-
 from bayesflow.networks import InferenceNetwork
 from bayesflow.networks.embeddings import FourierEmbedding
 
 
 # disable module check, use potential module after moving from experimental
 @serializable("bayesflow.networks", disable_module_check=True)
-class ContinuousTimeConsistencyModel(InferenceNetwork):
-    """(IN) Implements an sCM (simple, stable, and scalable Consistency Model)
-    with continous-time Consistency Training (CT) as described in [1].
-    The sampling procedure is taken from [2].
+class StableConsistencyModel(InferenceNetwork):
+    """(IN) Implements an sCM (simple, stable, and scalable Consistency Model) with continuous-time Consistency Training
+    (CT) as described in [1]. The sampling procedure is taken from [2].
 
     [1] Lu, C., & Song, Y. (2024).
     Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models
     arXiv preprint arXiv:2410.11081
 
     [2] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023).
-    Consistency Models.
-    arXiv preprint arXiv:2303.01469
+    Consistency Models. arXiv preprint arXiv:2303.01469
     """
 
+    MLP_DEFAULT_CONFIG = {
+        "widths": (256, 256, 256, 256, 256),
+        "activation": "mish",
+        "kernel_initializer": "he_normal",
+        "residual": True,
+        "dropout": 0.05,
+        "spectral_normalization": False,
+    }
+
+    WEIGHT_MLP_DEFAULT_CONFIG = {
+        "widths": (256,),
+        "activation": "mish",
+        "kernel_initializer": "he_normal",
+        "residual": False,
+        "dropout": 0.05,
+        "spectral_normalization": False,
+    }
+
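+    # used in _discretize_time: warn if the smallest non-zero time step exceeds this threshold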
+    EPS_WARN = 0.1
+
     def __init__(
         self,
-        subnet: str | keras.Layer = "mlp",
-        sigma_data: float = 1.0,
+        subnet: str | type | keras.Layer = "mlp",
+        sigma: float = 1.0,
         subnet_kwargs: dict[str, any] = None,
+        weight_mlp_kwargs: dict[str, any] = None,
         embedding_kwargs: dict[str, any] = None,
         **kwargs,
     ):
4766 """Creates an instance of an sCM to be used for consistency training (CT).
4867
4968 Parameters
5069 ----------
51- subnet : str or type, optional, default: "mlp"
52- A neural network type for the consistency model, will be
53- instantiated using subnet_kwargs.
54- sigma_data : float, optional, default: 1.0
55- Standard deviation of the target distribution
70+ subnet : str, type, or keras.Layer, optional, default="mlp"
71+ The neural network architecture used for the consistency model.
72+ If a string is provided, it should be a registered name (e.g., "mlp").
73+ If a type or keras.Layer is provided, it will be directly instantiated
74+ with the given ``subnet_kwargs``.
75+ sigma : float, optional, default=1.0
76+ Standard deviation of the target distribution for the consistency loss.
77+ Controls the scale of the noise injected during training.
78+ subnet_kwargs : dict[str, any], optional, default=None
79+ Keyword arguments passed to the constructor of the chosen ``subnet``. For example, number of hidden units,
80+ activation functions, or dropout settings.
81+ weight_mlp_kwargs : dict[str, any], optional, default=None
82+ Keyword arguments for an auxiliary MLP used to generate weights within the consistency model. Typically
83+ includes depth, hidden sizes, and non-linearity choices.
84+ embedding_kwargs : dict[str, any], optional, default=None
85+ Keyword arguments for the time embedding layer(s) used in the model
5686 **kwargs
57- Additional keyword arguments to the layer.
87+ Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
88+ (e.g., ``name``, ``dtype``, or ``trainable``).
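+
+        Examples
+        --------
+        A minimal construction sketch; the widths and dropout values below are
+        illustrative, not recommended settings:
+
+        >>> scm = StableConsistencyModel(
+        ...     subnet="mlp",
+        ...     sigma=1.0,
+        ...     subnet_kwargs={"widths": (128, 128), "dropout": 0.0},
+        ... )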
5889 """
5990 super ().__init__ (base_distribution = "normal" , ** kwargs )
6091
6192 subnet_kwargs = subnet_kwargs or {}
62-
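+        # for the default MLP, merge user-supplied kwargs over the defaults (user values take precedence)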
+        if subnet == "mlp":
+            subnet_kwargs = StableConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
+
         self.subnet_projector = keras.layers.Dense(
             units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
         )
 
-        self.weight_fn = MLP([256], dropout=0.0)
+        weight_mlp_kwargs = weight_mlp_kwargs or {}
+        weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
+        self.weight_fn = MLP(**weight_mlp_kwargs)
+
         self.weight_fn_projector = keras.layers.Dense(
             units=1, bias_initializer="zeros", kernel_initializer="zeros", name="weight_fn_projector"
         )
@@ -74,8 +110,7 @@ def __init__(
         self.time_emb = FourierEmbedding(**embedding_kwargs)
         self.time_emb_dim = self.time_emb.embed_dim
 
-        self.sigma_data = sigma_data
-
+        self.sigma = sigma
         self.seed_generator = keras.random.SeedGenerator()
 
     @classmethod
@@ -84,29 +119,33 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "subnet": self.subnet,
-            "sigma_data": self.sigma_data,
+            "sigma": self.sigma,
         }
 
         return base_config | serialize(config)
 
     def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
-        t = np.linspace(0.0, np.pi / 2, num_steps)
-        times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
-        times[0] = 0.0
+        t = keras.ops.linspace(0.0, pi / 2, num_steps)
+        times = keras.ops.exp((t - pi / 2) * rho) * pi / 2
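+        # e.g. num_steps=5, rho=3.5 yields roughly [0.006, 0.025, 0.100, 0.397, 1.571];
+        # the first entry is then set to exactly zero below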
+        times = keras.ops.concatenate([keras.ops.zeros((1,), dtype=keras.ops.dtype(times)), times[1:]], axis=0)
 
         # if rho is set too low, bad schedules can occur
-        EPS_WARN = 0.1
-        if times[1] > EPS_WARN:
-            print("Warning: The last time step is large.")
-            print(f"Increasing rho (was {rho}) or n_steps (was {num_steps}) might improve results.")
-        return ops.convert_to_tensor(times)
+        if times[1] > StableConsistencyModel.EPS_WARN:
+            logging.warning("The last time step is large.")
+            logging.warning(f"Increasing rho (was {rho}) or num_steps (was {num_steps}) might improve results.")
+        return times
 
     def build(self, xz_shape, conditions_shape=None):
-        super().build(xz_shape)
+        if self.built:
+            # building when the network is already built can cause issues with serialization
+            # see https://github.com/keras-team/keras/issues/21147
+            return
+
+        self.base_distribution.build(xz_shape)
         self.subnet_projector.units = xz_shape[-1]
 
         # construct input shape for subnet and subnet projector
@@ -134,17 +173,6 @@ def build(self, xz_shape, conditions_shape=None):
         input_shape = self.weight_fn.compute_output_shape(input_shape)
         self.weight_fn_projector.build(input_shape)
 
-    def call(
-        self,
-        xz: Tensor,
-        conditions: Tensor = None,
-        inverse: bool = False,
-        **kwargs,
-    ):
-        if inverse:
-            return self._inverse(xz, conditions=conditions, **kwargs)
-        return self._forward(xz, conditions=conditions, **kwargs)
-
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
@@ -172,8 +200,8 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         steps = kwargs.get("steps", 15)
         rho = kwargs.get("rho", 3.5)
 
-        # noise distribution has variance sigma_data
-        x = keras.ops.copy(z) * self.sigma_data
+        # noise distribution has standard deviation sigma
+        x = keras.ops.copy(z) * self.sigma
         discretized_time = keras.ops.flip(self._discretize_time(steps, rho=rho), axis=-1)
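+        # start from the largest time in the flipped schedule, t = pi / 2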
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
         x = self.consistency_function(x, t, conditions=conditions)
@@ -207,9 +235,9 @@ def consistency_function(
         **kwargs : dict, optional, default: {}
             Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate_valid([x / self.sigma_data, self.time_emb(t), conditions], axis=-1)
+        xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
         f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
-        out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
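+        # parameterization from [1]: f(x, t) = cos(t) * x - sin(t) * sigma * F(x / sigma, t),
+        # which satisfies the boundary condition f(x, 0) = x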
+        out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
         return out
 
     def compute_metrics(
@@ -226,17 +254,14 @@ def compute_metrics(
         c = 0.1
 
         # generate noise vector
-        z = (
-            keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
-            * self.sigma_data
-        )
+        z = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) * self.sigma
 
         # sample time
         tau = (
             keras.random.normal(keras.ops.shape(x)[:1], dtype=keras.ops.dtype(x), seed=self.seed_generator) * p_std
             + p_mean
         )
-        t_ = ops.arctan(ops.exp(tau) / self.sigma_data)
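+        # map the log-normal scale exp(tau) to an angle t = arctan(exp(tau) / sigma) in (0, pi / 2)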
+        t_ = ops.arctan(ops.exp(tau) / self.sigma)
         t = expand_right_as(t_, x)
 
         # generate noisy sample
@@ -251,23 +276,23 @@ def f_teacher(x, t):
             o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
             return self.subnet_projector(o)
 
-        primals = (xt / self.sigma_data, t)
+        primals = (xt / self.sigma, t)
         tangents = (
             ops.cos(t) * ops.sin(t) * dxtdt,
-            ops.cos(t) * ops.sin(t) * self.sigma_data,
+            ops.cos(t) * ops.sin(t) * self.sigma,
         )
 
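+        # a single forward-mode JVP evaluates the teacher output and its
+        # directional derivative along the tangents defined above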
         teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
         teacher_output = ops.stop_gradient(teacher_output)
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        xtc = concatenate_valid([xt / self.sigma_data, self.time_emb(t), conditions], axis=-1)
+        xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
         student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
 
         # calculate the tangent
-        g = -(ops.cos(t) ** 2) * (self.sigma_data * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
-            xt + self.sigma_data * cos_sin_dFdt
+        g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
+            xt + self.sigma * cos_sin_dFdt
         )
 
         # apply normalization to stabilize training
@@ -277,6 +302,7 @@ def f_teacher(x, t):
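+        # learned time-dependent weight w(t); exp(w) / D rescales the per-sample loss below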
         w = self.weight_fn_projector(self.weight_fn(expand_right_to(t_, 2)))
 
         D = ops.shape(x)[-1]
+
         loss = ops.mean(
             (ops.exp(w) / D)
             * ops.mean(