Skip to content

Commit 194a503

Browse files
committed
fix scale base dist
1 parent 548f51b commit 194a503

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def scale_base_distribution(self):
5454
return 1.0
5555
elif self.variance_type == "exploding":
5656
# e.g., EDM is a variance exploding schedule
57-
return ops.exp(-self._log_snr_min)
57+
return ops.sqrt(ops.exp(-self._log_snr_min))
5858
else:
5959
raise ValueError(f"Unknown variance type: {self.variance_type}")
6060

@@ -279,17 +279,20 @@ class EDMNoiseSchedule(NoiseSchedule):
279279
def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
280280
super().__init__(name="edm_noise_schedule", variance_type="exploding")
281281
self.sigma_data = sigma_data
282-
self.sigma_max = sigma_max
283-
self.sigma_min = sigma_min
282+
# training settings
284283
self.p_mean = -1.2
285284
self.p_std = 1.2
285+
# sampling settings
286+
self.sigma_max = sigma_max
287+
self.sigma_min = sigma_min
286288
self.rho = 7
287289

288290
# convert EDM parameters to signal-to-noise ratio formulation
289291
self._log_snr_min = -2 * ops.log(sigma_max)
290292
self._log_snr_max = -2 * ops.log(sigma_min)
291-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
292-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
293+
# t is not truncated for EDM by definition of the sampling schedule
294+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=False)
295+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=False)
293296

294297
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
295298
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""

0 commit comments

Comments
 (0)