Skip to content

Commit 3455ce1

Browse files
committed
rename prediction type
1 parent e32e8ad commit 3455ce1

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, name: str, variance_type: str):
4444
self.variance_type = variance_type # 'exploding' or 'preserving'
4545
self._log_snr_min = -15 # should be set in the subclasses
4646
self._log_snr_max = 15 # should be set in the subclasses
47+
self.sigma_data = 1.0
4748

4849
@property
4950
def scale_base_distribution(self):
@@ -381,7 +382,7 @@ def __init__(
381382
integrate_kwargs: dict[str, any] = None,
382383
subnet_kwargs: dict[str, any] = None,
383384
noise_schedule: str | NoiseSchedule = "cosine",
384-
prediction_type: str = "v",
385+
prediction_type: str = "velocity",
385386
**kwargs,
386387
):
387388
"""
@@ -406,7 +407,8 @@ def __init__(
406407
The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
407408
Default is "cosine".
408409
prediction_type: str, optional
409-
The type of prediction used in the diffusion model. Can be "eps", "v" or "F" (EDM). Default is "v".
410+
The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
411+
Default is "velocity".
410412
**kwargs
411413
Additional keyword arguments passed to the subnet and other components.
412414
"""
@@ -427,7 +429,7 @@ def __init__(
427429
# validate noise model
428430
self.noise_schedule.validate()
429431

430-
if prediction_type not in ["eps", "v", "F"]: # F is EDM
432+
if prediction_type not in ["velocity", "noise", "F"]: # F is EDM
431433
raise ValueError(f"Unknown prediction type: {prediction_type}")
432434
self.prediction_type = prediction_type
433435

@@ -496,10 +498,10 @@ def convert_prediction_to_x(
496498
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
497499
) -> Tensor:
498500
"""Convert the prediction of the neural network to the x space."""
499-
if self.prediction_type == "v":
501+
if self.prediction_type == "velocity":
500502
# convert v into x
501503
x = alpha_t * z - sigma_t * pred
502-
elif self.prediction_type == "eps":
504+
elif self.prediction_type == "noise":
503505
# convert noise prediction into x
504506
x = (z - sigma_t * pred) / alpha_t
505507
elif self.prediction_type == "x":
@@ -700,11 +702,11 @@ def compute_metrics(
700702
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=False
701703
)
702704
# convert x to epsilon prediction
703-
out = (alpha_t * diffused_x - x_pred) / sigma_t
705+
noise_pred = (alpha_t * diffused_x - x_pred) / sigma_t
704706

705707
# Calculate loss based on noise prediction
706708
weights_for_snr = self.noise_schedule.get_weights_for_snr(log_snr_t=log_snr_t)
707-
loss = weights_for_snr * ops.mean((out - eps_t) ** 2, axis=-1)
709+
loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1)
708710

709711
# apply sample weight
710712
loss = weighted_mean(loss, sample_weight)

0 commit comments

Comments
 (0)