Skip to content

Commit 7f67c58

Browse files
committed
diffusion model input
1 parent f916855 commit 7f67c58

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
if subnet == "mlp":
117117
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
118118
self.subnet = find_network(subnet, **subnet_kwargs)
119+
self._concatenate_subnet_input = subnet_kwargs.get("concatenate_subnet_input", True)
119120

120121
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
121122

@@ -149,6 +150,7 @@ def get_config(self):
149150
"prediction_type": self._prediction_type,
150151
"loss_type": self._loss_type,
151152
"integrate_kwargs": self.integrate_kwargs,
153+
"_concatenate_subnet_input": self._concatenate_subnet_input,
152154
}
153155
return base_config | serialize(config)
154156

@@ -197,6 +199,33 @@ def convert_prediction_to_x(
197199
return (z + sigma_t**2 * pred) / alpha_t
198200
raise ValueError(f"Unknown prediction type {self._prediction_type}.")
199201

202+
def prepare_subnet_input(self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None) -> Tensor:
203+
"""
204+
Prepares the input for the subnet either by concatenating the latent variable `xz`,
205+
the log signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
206+
207+
Parameters
208+
----------
209+
xz : Tensor
210+
The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
211+
log_snr : Tensor
212+
The log signal-to-noise ratio tensor, typically of shape (..., 1).
213+
conditions : Tensor, optional
214+
The optional conditioning tensor (e.g. parameters).
215+
216+
Returns
217+
-------
218+
Tensor
219+
The concatenated input tensor for the subnet or a tuple of tensors if concatenation is disabled.
220+
"""
221+
if self._concatenate_subnet_input:
222+
if conditions is None:
223+
return tensor_utils.concatenate_valid([xz, log_snr], axis=-1)
224+
else:
225+
return tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
226+
else:
227+
return xz, log_snr, conditions
228+
200229
def velocity(
201230
self,
202231
xz: Tensor,
@@ -221,7 +250,7 @@ def velocity(
221250
If True, computes the velocity for the stochastic formulation (SDE).
222251
If False, uses the deterministic formulation (ODE).
223252
conditions : Tensor, optional
224-
Optional conditional inputs to the network, such as conditioning variables
253+
Conditional inputs to the network, such as conditioning variables
225254
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
226255
training : bool, optional
227256
Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,11 +267,7 @@ def velocity(
238267
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
239268
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
240269

241-
if conditions is None:
242-
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
243-
else:
244-
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
245-
270+
xtc = self.prepare_subnet_input(xz, self._transform_log_snr(log_snr_t), conditions=conditions)
246271
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
247272

248273
x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
@@ -461,10 +486,7 @@ def compute_metrics(
461486
diffused_x = alpha_t * x + sigma_t * eps_t
462487

463488
# calculate output of the network
464-
if conditions is None:
465-
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
466-
else:
467-
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
489+
xtc = self.prepare_subnet_input(diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions)
468490
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
469491

470492
x_pred = self.convert_prediction_to_x(

0 commit comments

Comments
 (0)