Skip to content

Commit a387ed9

Browse files
committed
add serialization tests, minor fix to cosine schedule config
1 parent b9a7e36 commit a387ed9

File tree

5 files changed

+53
-2
lines changed

5 files changed

+53
-2
lines changed

bayesflow/experimental/diffusion_model/cosine_noise_schedule.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
"""
4242
super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
4343
self._shift = shift
44+
self._weighting = weighting
4445
self.log_snr_min = min_log_snr
4546
self.log_snr_max = max_log_snr
4647

@@ -75,7 +76,9 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
7576
return -factor * dsnr_dt
7677

7778
def get_config(self):
78-
return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift)
79+
return dict(
80+
min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting
81+
)
7982

8083
@classmethod
8184
def from_config(cls, config, custom_objects=None):

bayesflow/experimental/diffusion_model/noise_schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
138138
raise TypeError(f"Unknown weighting type: {self._weighting}")
139139

140140
def get_config(self):
141-
return dict(name=self.name, variance_type=self._variance_type)
141+
return dict(name=self.name, variance_type=self._variance_type, weighting=self._weighting)
142142

143143
@classmethod
144144
def from_config(cls, config, custom_objects=None):

tests/test_networks/test_diffusion_model/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
4+
@pytest.fixture()
5+
def cosine_noise_schedule():
6+
from bayesflow.experimental.diffusion_model import CosineNoiseSchedule
7+
8+
return CosineNoiseSchedule(min_log_snr=-12, max_log_snr=12, shift=0.1, weighting="likelihood_weighting")
9+
10+
11+
@pytest.fixture()
12+
def edm_noise_schedule():
13+
from bayesflow.experimental.diffusion_model import EDMNoiseSchedule
14+
15+
return EDMNoiseSchedule(sigma_data=10.0, sigma_min=1e-5, sigma_max=85.0)
16+
17+
18+
@pytest.fixture(
19+
params=["cosine_noise_schedule", "edm_noise_schedule"],
20+
scope="function",
21+
)
22+
def noise_schedule(request):
23+
return request.getfixturevalue(request.param)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
def test_serialize_deserialize_noise_schedule(noise_schedule):
2+
from bayesflow.utils.serialization import serialize, deserialize
3+
4+
serialized = serialize(noise_schedule)
5+
deserialized = deserialize(serialized)
6+
reserialized = serialize(deserialized)
7+
8+
assert serialized == reserialized
9+
t = 0.251
10+
x = 0.5
11+
training = True
12+
assert noise_schedule.get_log_snr(t, training=training) == deserialized.get_log_snr(t, training=training)
13+
assert noise_schedule.get_t_from_log_snr(t, training=training) == deserialized.get_t_from_log_snr(
14+
t, training=training
15+
)
16+
assert noise_schedule.derivative_log_snr(t, training=False) == deserialized.derivative_log_snr(t, training=False)
17+
assert noise_schedule.get_drift_diffusion(t, x, training=False) == deserialized.get_drift_diffusion(
18+
t, x, training=False
19+
)
20+
assert noise_schedule.get_alpha_sigma(t, training=training) == deserialized.get_alpha_sigma(t, training=training)
21+
assert noise_schedule.get_weights_for_snr(t) == deserialized.get_weights_for_snr(t)
22+
23+
24+
def test_validate_noise_schedule(noise_schedule):
25+
noise_schedule.validate()

0 commit comments

Comments
 (0)