@@ -45,49 +45,52 @@ class DiffusionModel(InferenceNetwork):
 
     INTEGRATE_DEFAULT_CONFIG = {
         "method": "euler",
-        "steps": 250,
+        "steps": 100,
     }
 
     def __init__(
         self,
         *,
         subnet: str | type = "mlp",
         noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
-        prediction_type: Literal["velocity", "noise", "F"] = "F",
+        prediction_type: Literal["velocity", "noise", "F", "x"] = "F",
         loss_type: Literal["velocity", "noise", "F"] = "noise",
         subnet_kwargs: dict[str, any] = None,
         schedule_kwargs: dict[str, any] = None,
         integrate_kwargs: dict[str, any] = None,
         **kwargs,
     ):
         """
-        Initializes a diffusion model with configurable subnet architecture.
+        Initializes a diffusion model with configurable subnet architecture, noise schedule,
+        and prediction/loss types for amortized Bayesian inference.
 
-        This model learns a transformation from a Gaussian latent distribution to a target distribution using a
-        specified subnet type, which can be an MLP or a custom network.
-
-        The integration can be customized with additional parameters available in the integrate_kwargs
-        configuration dictionary. Different noise schedules and prediction types are available.
+        Note that score-based diffusion is the slowest of the available samplers,
+        so expect inference to be slower than with flow matching and much slower than with normalizing flows.
 
         Parameters
         ----------
         subnet : str or type, optional
-            The architecture used for the transformation network. Can be "mlp" or a custom
-            callable network. Default is "mlp".
+            Architecture for the transformation network. Can be "mlp" or a custom network class.
+            Default is "mlp".
+        noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
+            Noise schedule controlling the diffusion dynamics. Can be a string identifier,
+            a schedule class, or a pre-initialized schedule instance. Default is "edm".
+        prediction_type : {'velocity', 'noise', 'F', 'x'}, optional
+            Output format of the model's prediction. Default is "F".
+        loss_type : {'velocity', 'noise', 'F'}, optional
+            Loss function used to train the model. Default is "noise".
+        subnet_kwargs : dict[str, any], optional
+            Additional keyword arguments passed to the subnet constructor. Default is None.
+        schedule_kwargs : dict[str, any], optional
+            Additional keyword arguments passed to the noise schedule constructor. Default is None.
         integrate_kwargs : dict[str, any], optional
-            Additional keyword arguments for the integration process. Default is None.
-        noise_schedule : Literal['edm', 'cosine'], dict or type, optional
-            The noise schedule used for the diffusion process. Default is "F"
-        loss_type: Literal['velocity', 'noise', 'F'], optional
-            The type los loss used in the diffusion model. Default is "noise".
-        prediction_type: Literal['velocity', 'noise', 'F'], optional
-            The type of prediction used in the diffusion model. Default is "F".
+            Configuration dictionary for integration during training or inference. Default is None.
         **kwargs
-            Additional keyword arguments passed to the subnet and other components.
+            Additional keyword arguments passed to the base class and internal components.
         """
         super().__init__(base_distribution="normal", **kwargs)
 
-        if prediction_type not in ["noise", "velocity", "F"]:
+        if prediction_type not in ["noise", "velocity", "F", "x"]:
             raise TypeError(f"Unknown prediction type: {prediction_type}")
 
         if loss_type not in ["noise", "velocity", "F"]:
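For orientation, a minimal usage sketch of the constructor as it looks after this change. Only the keyword arguments mirror the signature above; the import path and the `integrate_kwargs` contents are assumptions for illustration, not part of the diff.

```python
# Hypothetical usage sketch -- the import path is an assumption, not verified here;
# the constructor keywords mirror the signature shown in this hunk.
from bayesflow.networks import DiffusionModel  # assumed location of the class

model = DiffusionModel(
    subnet="mlp",
    noise_schedule="edm",
    prediction_type="x",                                  # newly accepted by this change
    loss_type="noise",
    integrate_kwargs={"method": "euler", "steps": 100},   # matches the new default config
)
```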
@@ -157,22 +160,42 @@ def from_config(cls, config, custom_objects=None):
     def convert_prediction_to_x(
         self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor
     ) -> Tensor:
-        """Convert the prediction of the neural network to the x space."""
+        """
+        Converts the neural network prediction into the denoised data `x`, depending on
+        the prediction type configured for the model.
+
+        Parameters
+        ----------
+        pred : Tensor
+            The output prediction from the neural network, typically representing noise,
+            velocity, or a transformation of the clean signal.
+        z : Tensor
+            The noisy latent variable `z` to be denoised.
+        alpha_t : Tensor
+            The noise schedule's scaling factor for the clean signal at time `t`.
+        sigma_t : Tensor
+            The standard deviation of the noise at time `t`.
+        log_snr_t : Tensor
+            The log signal-to-noise ratio at time `t`.
+
+        Returns
+        -------
+        Tensor
+            The reconstructed clean signal `x` from the model prediction.
+        """
         if self._prediction_type == "velocity":
-            x = alpha_t * z - sigma_t * pred
+            return alpha_t * z - sigma_t * pred
         elif self._prediction_type == "noise":
-            x = (z - sigma_t * pred) / alpha_t
+            return (z - sigma_t * pred) / alpha_t
         elif self._prediction_type == "F":
-            sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
+            sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0)
             x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
             x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
-            x = x1 * z + x2 * pred
+            return x1 * z + x2 * pred
         elif self._prediction_type == "x":
-            x = pred
+            return pred
         else:
-            x = (z + sigma_t**2 * pred) / alpha_t
-
-        return x
+            return (z + sigma_t**2 * pred) / alpha_t
 
     def velocity(
         self,
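To make the conversions above concrete, here is an illustrative, standalone NumPy re-implementation of the same formulas. This is not the library code; the helper name and arguments are assumptions, and `sigma_data` defaults to 1.0 exactly as in the hunk above.

```python
import numpy as np

def prediction_to_x(pred, z, alpha_t, sigma_t, log_snr_t, prediction_type, sigma_data=1.0):
    """Standalone sketch of the conversions in convert_prediction_to_x (illustration only)."""
    if prediction_type == "velocity":
        return alpha_t * z - sigma_t * pred
    if prediction_type == "noise":
        return (z - sigma_t * pred) / alpha_t
    if prediction_type == "F":
        # EDM-style preconditioned output
        var = np.exp(-log_snr_t) + sigma_data**2
        return (sigma_data**2 * alpha_t / var) * z + (np.exp(-log_snr_t / 2) * sigma_data / np.sqrt(var)) * pred
    if prediction_type == "x":
        return pred
    # fallback mirrors the final branch above, which treats the prediction as a score
    return (z + sigma_t**2 * pred) / alpha_t
```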
@@ -182,10 +205,37 @@ def velocity(
         conditions: Tensor = None,
         training: bool = False,
     ) -> Tensor:
+        """
+        Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for either
+        a stochastic differential equation (SDE) or an ordinary differential equation (ODE).
+
+        Parameters
+        ----------
+        xz : Tensor
+            The current state of the latent variable `z`, typically of shape (..., D),
+            where D is the dimensionality of the latent space.
+        time : float or Tensor
+            Scalar or tensor representing the time (or noise level) at which the velocity
+            should be computed. Will be broadcast to the shape of `xz`.
+        stochastic_solver : bool
+            If True, computes the velocity for the stochastic formulation (SDE).
+            If False, uses the deterministic formulation (ODE).
+        conditions : Tensor, optional
+            Optional conditional inputs to the network, such as conditioning variables
+            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
+        training : bool, optional
+            Whether the model is in training mode. Affects the behavior of dropout, batch norm,
+            or other stochastic layers. Default is False.
+
+        Returns
+        -------
+        Tensor
+            The velocity tensor of the same shape as `xz`, representing the right-hand
+            side of the SDE or ODE at the given `time`.
+        """
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
-
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
         if conditions is None:
@@ -196,11 +246,11 @@ def velocity(
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
-        # convert x to score
+
         score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
 
         # compute velocity f, g of the SDE or ODE
-        f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
+        f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
 
         if stochastic_solver:
             # for the SDE: d(z) = [f(z, t) - g(t)^2 * score(z, lambda)] dt + g(t) dW
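The hunk is cut off right after the SDE comment, so the following is only a sketch of how drift, squared diffusion, and score typically combine into the returned velocity; the ODE branch here is an assumption based on the standard probability-flow relationship, not the literal continuation of the diff.

```python
# Illustrative combination of drift f, squared diffusion g^2, and the score
# (standard score-based SDE/ODE result; not copied from the truncated branch above).
def reverse_velocity(f, g_squared, score, stochastic_solver):
    if stochastic_solver:
        # reverse-time SDE drift: dz = [f - g^2 * score] dt + g dW
        return f - g_squared * score
    # probability-flow ODE: dz = [f - 0.5 * g^2 * score] dt
    return f - 0.5 * g_squared * score
```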
@@ -218,12 +268,12 @@ def diffusion_term(
         training: bool = False,
     ) -> Tensor:
         """
-        Compute the diffusion term (standard deviation of the noise) for a given time.
+        Compute the diffusion term (standard deviation of the noise) at a given time.
 
         Parameters
         ----------
         xz : Tensor
-            Input tensor with shape [..., D], typically representing the latent state or concatenated variables.
+            Input tensor of shape (..., D), typically representing the target or latent variables at the given time.
         time : float or Tensor
             The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of `xz`.
         training : bool, optional
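Putting the pieces together, a single Euler-Maruyama update using `velocity()` and `diffusion_term()` might look like the sketch below. This sampling loop is not part of the diff; NumPy arrays stand in for backend tensors, and the call signatures are assumed from the definitions above.

```python
import numpy as np

def euler_maruyama_step(model, z, t, dt, conditions=None, rng=None):
    # one reverse-time SDE step: z <- z + drift * dt + g * sqrt(|dt|) * eps
    rng = rng or np.random.default_rng()
    drift = model.velocity(z, time=t, stochastic_solver=True, conditions=conditions)
    g = model.diffusion_term(z, time=t)
    eps = rng.standard_normal(np.shape(z))
    return z + drift * dt + g * np.sqrt(abs(dt)) * eps
```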