Skip to content

Commit 877a720

Browse files
committed
Improve docs and mark option for x prediction in literal
1 parent f740784 commit 877a720

File tree

1 file changed

+83
-33
lines changed

1 file changed

+83
-33
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -45,49 +45,52 @@ class DiffusionModel(InferenceNetwork):
4545

4646
INTEGRATE_DEFAULT_CONFIG = {
4747
"method": "euler",
48-
"steps": 250,
48+
"steps": 100,
4949
}
5050

5151
def __init__(
5252
self,
5353
*,
5454
subnet: str | type = "mlp",
5555
noise_schedule: Literal["edm", "cosine"] | NoiseSchedule | type = "edm",
56-
prediction_type: Literal["velocity", "noise", "F"] = "F",
56+
prediction_type: Literal["velocity", "noise", "F", "x"] = "F",
5757
loss_type: Literal["velocity", "noise", "F"] = "noise",
5858
subnet_kwargs: dict[str, any] = None,
5959
schedule_kwargs: dict[str, any] = None,
6060
integrate_kwargs: dict[str, any] = None,
6161
**kwargs,
6262
):
6363
"""
64-
Initializes a diffusion model with configurable subnet architecture.
64+
Initializes a diffusion model with configurable subnet architecture, noise schedule,
65+
and prediction/loss types for amortized Bayesian inference.
6566
66-
This model learns a transformation from a Gaussian latent distribution to a target distribution using a
67-
specified subnet type, which can be an MLP or a custom network.
68-
69-
The integration can be customized with additional parameters available in the integrate_kwargs
70-
configuration dictionary. Different noise schedules and prediction types are available.
67+
Note that score-based diffusion is the slowest of all available samplers,
68+
so expect slower inference times than flow matching and much slower than normalizing flows.
7169
7270
Parameters
7371
----------
7472
subnet : str or type, optional
75-
The architecture used for the transformation network. Can be "mlp" or a custom
76-
callable network. Default is "mlp".
73+
Architecture for the transformation network. Can be "mlp" or a custom network class.
74+
Default is "mlp".
75+
noise_schedule : {'edm', 'cosine'} or NoiseSchedule or type, optional
76+
Noise schedule controlling the diffusion dynamics. Can be a string identifier,
77+
a schedule class, or a pre-initialized schedule instance. Default is "edm".
78+
prediction_type : {'velocity', 'noise', 'F', 'x'}, optional
79+
Output format of the model's prediction. Default is "F".
80+
loss_type : {'velocity', 'noise', 'F'}, optional
81+
Loss function used to train the model. Default is "noise".
82+
subnet_kwargs : dict[str, any], optional
83+
Additional keyword arguments passed to the subnet constructor. Default is None.
84+
schedule_kwargs : dict[str, any], optional
85+
Additional keyword arguments passed to the noise schedule constructor. Default is None.
7786
integrate_kwargs : dict[str, any], optional
78-
Additional keyword arguments for the integration process. Default is None.
79-
noise_schedule : Literal['edm', 'cosine'], dict or type, optional
80-
The noise schedule used for the diffusion process. Default is "F"
81-
loss_type: Literal['velocity', 'noise', 'F'], optional
82-
The type los loss used in the diffusion model. Default is "noise".
83-
prediction_type: Literal['velocity', 'noise', 'F'], optional
84-
The type of prediction used in the diffusion model. Default is "F".
87+
Configuration dictionary for integration during training or inference. Default is None.
8588
**kwargs
86-
Additional keyword arguments passed to the subnet and other components.
89+
Additional keyword arguments passed to the base class and internal components.
8790
"""
8891
super().__init__(base_distribution="normal", **kwargs)
8992

90-
if prediction_type not in ["noise", "velocity", "F"]:
93+
if prediction_type not in ["noise", "velocity", "F", "x"]:
9194
raise TypeError(f"Unknown prediction type: {prediction_type}")
9295

9396
if loss_type not in ["noise", "velocity", "F"]:
@@ -157,22 +160,42 @@ def from_config(cls, config, custom_objects=None):
157160
def convert_prediction_to_x(
158161
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor
159162
) -> Tensor:
160-
"""Convert the prediction of the neural network to the x space."""
163+
"""
164+
Converts the neural network prediction into the denoised data `x`, depending on
165+
the prediction type configured for the model.
166+
167+
Parameters
168+
----------
169+
pred : Tensor
170+
The output prediction from the neural network, typically representing noise,
171+
velocity, or a transformation of the clean signal.
172+
z : Tensor
173+
The noisy latent variable `z` to be denoised.
174+
alpha_t : Tensor
175+
The noise schedule's scaling factor for the clean signal at time `t`.
176+
sigma_t : Tensor
177+
The standard deviation of the noise at time `t`.
178+
log_snr_t : Tensor
179+
The log signal-to-noise ratio at time `t`.
180+
181+
Returns
182+
-------
183+
Tensor
184+
The reconstructed clean signal `x` from the model prediction.
185+
"""
161186
if self._prediction_type == "velocity":
162-
x = alpha_t * z - sigma_t * pred
187+
return alpha_t * z - sigma_t * pred
163188
elif self._prediction_type == "noise":
164-
x = (z - sigma_t * pred) / alpha_t
189+
return (z - sigma_t * pred) / alpha_t
165190
elif self._prediction_type == "F":
166-
sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
191+
sigma_data = getattr(self.noise_schedule, "sigma_data", 1.0)
167192
x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
168193
x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
169-
x = x1 * z + x2 * pred
194+
return x1 * z + x2 * pred
170195
elif self._prediction_type == "x":
171-
x = pred
196+
return pred
172197
else:
173-
x = (z + sigma_t**2 * pred) / alpha_t
174-
175-
return x
198+
return (z + sigma_t**2 * pred) / alpha_t
176199

177200
def velocity(
178201
self,
@@ -182,10 +205,37 @@ def velocity(
182205
conditions: Tensor = None,
183206
training: bool = False,
184207
) -> Tensor:
208+
"""
209+
Computes the velocity (i.e., time derivative) of the target or latent variable `xz` for either
210+
a stochastic differential equation (SDE) or ordinary differential equation (ODE).
211+
212+
Parameters
213+
----------
214+
xz : Tensor
215+
The current state of the latent variable `z`, typically of shape (..., D),
216+
where D is the dimensionality of the latent space.
217+
time : float or Tensor
218+
Scalar or tensor representing the time (or noise level) at which the velocity
219+
should be computed. Will be broadcast to `xz`.
220+
stochastic_solver : bool
221+
If True, computes the velocity for the stochastic formulation (SDE).
222+
If False, uses the deterministic formulation (ODE).
223+
conditions : Tensor, optional
224+
Optional conditional inputs to the network, such as conditioning variables
225+
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
226+
training : bool, optional
227+
Whether the model is in training mode. Affects behavior of dropout, batch norm,
228+
or other stochastic layers. Default is False.
229+
230+
Returns
231+
-------
232+
Tensor
233+
The velocity tensor of the same shape as `xz`, representing the right-hand
234+
side of the SDE or ODE at the given `time`.
235+
"""
185236
# calculate the current noise level and transform into correct shape
186237
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
187238
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
188-
189239
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
190240

191241
if conditions is None:
@@ -196,11 +246,11 @@ def velocity(
196246
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
197247

198248
x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
199-
# convert x to score
249+
200250
score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
201251

202252
# compute velocity f, g of the SDE or ODE
203-
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
253+
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
204254

205255
if stochastic_solver:
206256
# for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
@@ -218,12 +268,12 @@ def diffusion_term(
218268
training: bool = False,
219269
) -> Tensor:
220270
"""
221-
Compute the diffusion term (standard deviation of the noise) for a given time.
271+
Compute the diffusion term (standard deviation of the noise) at a given time.
222272
223273
Parameters
224274
----------
225275
xz : Tensor
226-
Input tensor with shape [..., D], typically representing the latent state or concatenated variables.
276+
Input tensor of shape (..., D), typically representing the target or latent variables at the given time.
227277
time : float or Tensor
228278
The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of `xz`.
229279
training : bool, optional

0 commit comments

Comments
 (0)