Skip to content

Commit b9a7e36

Browse files
committed
restructure noise schedules, test find_noise_schedule
1 parent f4f0d11 commit b9a7e36

File tree

6 files changed

+164
-86
lines changed

6 files changed

+164
-86
lines changed

bayesflow/experimental/diffusion_model/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from .diffusion_model import DiffusionModel
2-
from .noise_schedules import CosineNoiseSchedule, EDMNoiseSchedule, NoiseSchedule
2+
from .noise_schedule import NoiseSchedule
3+
from .cosine_noise_schedule import CosineNoiseSchedule
4+
from .edm_noise_schedule import EDMNoiseSchedule
5+
from .dispatch import find_noise_schedule
36

47
from ...utils._docs import _add_imports_to_all
58

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import math
2+
from typing import Union, Literal
3+
4+
from keras import ops
5+
6+
from bayesflow.types import Tensor
7+
from bayesflow.utils.serialization import deserialize, serializable
8+
9+
from .noise_schedule import NoiseSchedule
10+
11+
12+
# disable module check, use potential module after moving from experimental
13+
@serializable("bayesflow.networks", disable_module_check=True)
14+
class CosineNoiseSchedule(NoiseSchedule):
15+
"""Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
16+
17+
[1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)
18+
"""
19+
20+
def __init__(
21+
self,
22+
min_log_snr: float = -15,
23+
max_log_snr: float = 15,
24+
shift: float = 0.0,
25+
weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
26+
):
27+
"""
28+
Initialize the cosine noise schedule.
29+
30+
Parameters
31+
----------
32+
min_log_snr : float, optional
33+
The minimum log signal-to-noise ratio (lambda). Default is -15.
34+
max_log_snr : float, optional
35+
The maximum log signal-to-noise ratio (lambda). Default is 15.
36+
shift : float, optional
37+
Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
38+
For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
39+
weighting : Literal["sigmoid", "likelihood_weighting"], optional
40+
The type of weighting function to use for the noise schedule. Default is "sigmoid".
41+
"""
42+
super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
43+
self._shift = shift
44+
self.log_snr_min = min_log_snr
45+
self.log_snr_max = max_log_snr
46+
47+
self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
48+
self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
49+
50+
def _truncated_t(self, t: Tensor) -> Tensor:
51+
return self._t_min + (self._t_max - self._t_min) * t
52+
53+
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
54+
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
55+
t_trunc = self._truncated_t(t)
56+
return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift
57+
58+
def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
59+
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
60+
# SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
61+
return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))
62+
63+
def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
64+
"""Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
65+
t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
66+
67+
# Compute the truncated time t_trunc
68+
t_trunc = self._truncated_t(t)
69+
dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)
70+
71+
# Using the chain rule on f(t) = log(1 + e^(-snr(t))):
72+
# f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
73+
dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
74+
factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
75+
return -factor * dsnr_dt
76+
77+
def get_config(self):
78+
return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift)
79+
80+
@classmethod
81+
def from_config(cls, config, custom_objects=None):
82+
return cls(**deserialize(config, custom_objects=custom_objects))

bayesflow/experimental/diffusion_model/dispatch.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from functools import singledispatch
2-
from .noise_schedules import NoiseSchedule
2+
from .noise_schedule import NoiseSchedule
33

44

55
@singledispatch
66
def find_noise_schedule(arg, *args, **kwargs):
7-
raise TypeError(f"Unknown noise schedule: {arg!r}")
7+
raise TypeError(f"Not a noise schedule: {arg!r}. Please pass an object of type 'NoiseSchedule'.")
88

99

1010
@find_noise_schedule.register
@@ -16,11 +16,11 @@ def _(noise_schedule: NoiseSchedule):
1616
def _(name: str, *args, **kwargs):
1717
match name.lower():
1818
case "cosine":
19-
from .noise_schedules import CosineNoiseSchedule
19+
from .cosine_noise_schedule import CosineNoiseSchedule
2020

2121
return CosineNoiseSchedule()
2222
case "edm":
23-
from .noise_schedules import EDMNoiseSchedule
23+
from .edm_noise_schedule import EDMNoiseSchedule
2424

2525
return EDMNoiseSchedule()
2626
case other:
@@ -33,11 +33,11 @@ def _(config: dict, *args, **kwargs):
3333
params = {k: v for k, v in config.items() if k != "name"}
3434
match name:
3535
case "cosine":
36-
from .noise_schedules import CosineNoiseSchedule
36+
from .cosine_noise_schedule import CosineNoiseSchedule
3737

3838
return CosineNoiseSchedule(**params)
3939
case "edm":
40-
from .noise_schedules import EDMNoiseSchedule
40+
from .edm_noise_schedule import EDMNoiseSchedule
4141

4242
return EDMNoiseSchedule(**params)
4343
case other:
@@ -49,8 +49,3 @@ def _(cls: type, *args, **kwargs):
4949
if issubclass(cls, NoiseSchedule):
5050
return cls(*args, **kwargs)
5151
raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}")
52-
53-
54-
@find_noise_schedule.register
55-
def _(schedule: type, *args, **kwargs):
56-
return schedule

bayesflow/experimental/diffusion_model/noise_schedules.py renamed to bayesflow/experimental/diffusion_model/edm_noise_schedule.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,12 @@
11
import math
2-
from typing import Union, Literal
2+
from typing import Union
33

44
from keras import ops
55

66
from bayesflow.types import Tensor
77
from bayesflow.utils.serialization import deserialize, serializable
88

9-
from .noise_schedule_base import NoiseSchedule
10-
11-
12-
@serializable("bayesflow.experimental")
13-
class CosineNoiseSchedule(NoiseSchedule):
14-
"""Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
15-
16-
[1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)
17-
"""
18-
19-
def __init__(
20-
self,
21-
min_log_snr: float = -15,
22-
max_log_snr: float = 15,
23-
shift: float = 0.0,
24-
weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
25-
):
26-
"""
27-
Initialize the cosine noise schedule.
28-
29-
Parameters
30-
----------
31-
min_log_snr : float, optional
32-
The minimum log signal-to-noise ratio (lambda). Default is -15.
33-
max_log_snr : float, optional
34-
The maximum log signal-to-noise ratio (lambda). Default is 15.
35-
shift : float, optional
36-
Shift the log signal-to-noise ratio (lambda) by this amount. Default is 0.0.
37-
For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
38-
weighting : Literal["sigmoid", "likelihood_weighting"], optional
39-
The type of weighting function to use for the noise schedule. Default is "sigmoid".
40-
"""
41-
super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
42-
self._shift = shift
43-
self.log_snr_min = min_log_snr
44-
self.log_snr_max = max_log_snr
45-
46-
self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
47-
self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
48-
49-
def _truncated_t(self, t: Tensor) -> Tensor:
50-
return self._t_min + (self._t_max - self._t_min) * t
51-
52-
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
53-
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
54-
t_trunc = self._truncated_t(t)
55-
return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift
56-
57-
def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
58-
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
59-
# SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
60-
return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))
61-
62-
def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
63-
"""Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
64-
t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
65-
66-
# Compute the truncated time t_trunc
67-
t_trunc = self._truncated_t(t)
68-
dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)
69-
70-
# Using the chain rule on f(t) = log(1 + e^(-snr(t))):
71-
# f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
72-
dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
73-
factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
74-
return -factor * dsnr_dt
75-
76-
def get_config(self):
77-
return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift)
78-
79-
@classmethod
80-
def from_config(cls, config, custom_objects=None):
81-
return cls(**deserialize(config, custom_objects=custom_objects))
9+
from .noise_schedule import NoiseSchedule
8210

8311

8412
# disable module check, use potential module after moving from experimental

tests/test_utils/test_dispatch.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from bayesflow.utils import find_inference_network, find_distribution, find_summary_network
5+
from bayesflow.experimental.diffusion_model import find_noise_schedule
56

67

78
# --- Tests for find_inference_network.py ---
@@ -168,3 +169,72 @@ def test_find_summary_network_unknown_name():
168169
def test_find_summary_network_invalid_type():
169170
with pytest.raises(TypeError):
170171
find_summary_network(0.1234)
172+
173+
174+
def test_find_noise_schedule_by_name():
175+
from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule
176+
177+
schedule = find_noise_schedule("cosine")
178+
assert isinstance(schedule, CosineNoiseSchedule)
179+
180+
schedule = find_noise_schedule("edm")
181+
assert isinstance(schedule, EDMNoiseSchedule)
182+
183+
184+
def test_find_noise_schedule_unknown_name():
185+
with pytest.raises(ValueError):
186+
find_noise_schedule("unknown_noise_schedule")
187+
188+
189+
def test_pass_noise_schedule():
190+
from bayesflow.experimental.diffusion_model import NoiseSchedule
191+
192+
class CustomNoiseSchedule(NoiseSchedule):
193+
def __init__(self):
194+
pass
195+
196+
def get_log_snr(self, t, training):
197+
pass
198+
199+
def get_t_from_log_snr(self, log_snr_t, training):
200+
pass
201+
202+
def derivative_log_snr(self, log_snr_t, training):
203+
pass
204+
205+
schedule = CustomNoiseSchedule()
206+
assert schedule is find_noise_schedule(schedule)
207+
208+
209+
def test_pass_noise_schedule_type():
210+
from bayesflow.experimental.diffusion_model import EDMNoiseSchedule
211+
212+
schedule = find_noise_schedule(EDMNoiseSchedule, sigma_data=10.0)
213+
assert isinstance(schedule, EDMNoiseSchedule)
214+
assert schedule.sigma_data == 10.0
215+
216+
217+
def test_find_noise_schedule_by_dict():
218+
from bayesflow.experimental.diffusion_model import CosineNoiseSchedule, EDMNoiseSchedule
219+
220+
schedule = find_noise_schedule({"name": "cosine"})
221+
assert isinstance(schedule, CosineNoiseSchedule)
222+
223+
schedule = find_noise_schedule({"name": "edm", "sigma_data": 10})
224+
assert isinstance(schedule, EDMNoiseSchedule)
225+
assert schedule.sigma_data == 10
226+
227+
228+
def test_find_noise_schedule_unknown_name_in_dict():
229+
with pytest.raises(ValueError):
230+
find_noise_schedule({"name": "unknown_noise_schedule"})
231+
232+
233+
def test_find_noise_schedule_invalid_class():
234+
with pytest.raises(TypeError):
235+
find_noise_schedule(int)
236+
237+
238+
def test_find_noise_schedule_invalid_object():
239+
with pytest.raises(TypeError):
240+
find_noise_schedule(1.0)

0 commit comments

Comments
 (0)