Skip to content

Commit d8d6246

Browse files
committed
Merge branch 'feat-diffusion-model' of github.com:bayesflow-org/bayesflow into feat-diffusion-model
2 parents d82e2bf + 6031212 commit d8d6246

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Sequence
22
from abc import ABC, abstractmethod
3+
from typing import Union
34
import keras
45
from keras import ops
56
import warnings
@@ -61,7 +62,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
6162
pass
6263

6364
@abstractmethod
64-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
65+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
6566
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
6667
pass
6768

@@ -141,7 +142,7 @@ class LinearNoiseSchedule(NoiseSchedule):
141142
"""
142143

143144
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
144-
super().__init__(name="linear_noise_schedule")
145+
super().__init__(name="linear_noise_schedule", variance_type="preserving")
145146
self._log_snr_min = min_log_snr
146147
self._log_snr_max = max_log_snr
147148

@@ -154,7 +155,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
154155
# SNR = -log(exp(t^2) - 1)
155156
return -ops.log(ops.exp(ops.square(t_trunc)) - 1)
156157

157-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
158+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
158159
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
159160
# SNR = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-snr)))
160161
return ops.sqrt(ops.log(1 + ops.exp(-log_snr_t)))
@@ -213,7 +214,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
213214
# SNR = -2 * log(tan(pi*t/2))
214215
return -2 * ops.log(ops.tan(math.pi * t_trunc / 2)) + 2 * self._s_shift_cosine
215216

216-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
217+
def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
217218
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
218219
# SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
219220
return 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
@@ -289,7 +290,7 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
289290
)
290291
return snr
291292

292-
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
293+
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
293294
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
294295
if training:
295296
# SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)
@@ -551,8 +552,8 @@ def _forward(
551552
) -> Tensor | tuple[Tensor, Tensor]:
552553
integrate_kwargs = (
553554
{
554-
"start_time": self.noise_schedule._t_min,
555-
"stop_time": self.noise_schedule._t_max,
555+
"start_time": 1.0,
556+
"stop_time": 0.0,
556557
}
557558
| self.integrate_kwargs
558559
| kwargs
@@ -600,8 +601,8 @@ def _inverse(
600601
) -> Tensor | tuple[Tensor, Tensor]:
601602
integrate_kwargs = (
602603
{
603-
"start_time": self.noise_schedule._t_max,
604-
"stop_time": self.noise_schedule._t_min,
604+
"start_time": 1.0,
605+
"stop_time": 0.0,
605606
}
606607
| self.integrate_kwargs
607608
| kwargs

0 commit comments

Comments
 (0)