Skip to content

Commit 6031212

Browse files
committed
integration should be from 1 to 0
1 parent 01b33dc commit 6031212

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Sequence
22
from abc import ABC, abstractmethod
3+
from typing import Union
34
import keras
45
from keras import ops
56

@@ -60,7 +61,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
6061
pass
6162

6263
@abstractmethod
63-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
64+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
6465
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
6566
pass
6667

@@ -140,7 +141,7 @@ class LinearNoiseSchedule(NoiseSchedule):
140141
"""
141142

142143
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
143-
super().__init__(name="linear_noise_schedule")
144+
super().__init__(name="linear_noise_schedule", variance_type="preserving")
144145
self._log_snr_min = min_log_snr
145146
self._log_snr_max = max_log_snr
146147

@@ -153,7 +154,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
153154
# SNR = -log(exp(t^2) - 1)
154155
return -ops.log(ops.exp(ops.square(t_trunc)) - 1)
155156

156-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
157+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
157158
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
158159
# SNR = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-snr)))
159160
return ops.sqrt(ops.log(1 + ops.exp(-log_snr_t)))
@@ -212,7 +213,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
212213
# SNR = -2 * log(tan(pi*t/2))
213214
return -2 * ops.log(ops.tan(math.pi * t_trunc / 2)) + 2 * self._s_shift_cosine
214215

215-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
216+
def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
216217
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
217218
# SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
218219
return 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
@@ -288,7 +289,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
288289
)
289290
return snr
290291

291-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
292+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
292293
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
293294
if training:
294295
# SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
@@ -543,8 +544,8 @@ def _forward(
543544
) -> Tensor | tuple[Tensor, Tensor]:
544545
integrate_kwargs = (
545546
{
546-
"start_time": self.noise_schedule._t_min,
547-
"stop_time": self.noise_schedule._t_max,
547+
"start_time": 1.0,
548+
"stop_time": 0.0,
548549
}
549550
| self.integrate_kwargs
550551
| kwargs
@@ -592,8 +593,8 @@ def _inverse(
592593
) -> Tensor | tuple[Tensor, Tensor]:
593594
integrate_kwargs = (
594595
{
595-
"start_time": self.noise_schedule._t_max,
596-
"stop_time": self.noise_schedule._t_min,
596+
"start_time": 1.0,
597+
"stop_time": 0.0,
597598
}
598599
| self.integrate_kwargs
599600
| kwargs

0 commit comments

Comments
 (0)