Skip to content

Commit d43ea98

Browse files
committed
Move all files into a separate directory, fix serialization
1 parent 0e5a48c commit d43ea98

File tree

7 files changed

+73
-59
lines changed

7 files changed

+73
-59
lines changed

bayesflow/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99

1010
from ..utils._docs import _add_imports_to_all
1111

12-
_add_imports_to_all(include_modules=[])
12+
_add_imports_to_all(include_modules=["diffusion_model"])
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .diffusion_model import DiffusionModel
2+
from .noise_schedules import CosineNoiseSchedule, EDMNoiseSchedule, NoiseSchedule
3+
4+
from ...utils._docs import _add_imports_to_all
5+
6+
_add_imports_to_all(include_modules=[])

bayesflow/experimental/diffusion_model.py renamed to bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
weighted_mean,
1515
integrate,
1616
integrate_stochastic,
17-
find_noise_schedule,
1817
logging,
1918
tensor_utils,
2019
)
20+
from .dispatch import find_noise_schedule
2121
from bayesflow.utils.serialization import serialize, deserialize, serializable
2222

2323

24-
@serializable("bayesflow.experimental")
24+
# disable module check, use potential module after moving from experimental
25+
@serializable("bayesflow.networks", disable_module_check=True)
2526
class DiffusionModel(InferenceNetwork):
2627
"""Diffusion Model as described in this overview paper [1].
2728
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from functools import singledispatch
2+
from .noise_schedules import NoiseSchedule
3+
4+
5+
@singledispatch
6+
def find_noise_schedule(arg, *args, **kwargs):
7+
raise TypeError(f"Unknown noise schedule: {arg!r}")
8+
9+
10+
@find_noise_schedule.register
11+
def _(noise_schedule: NoiseSchedule):
12+
return noise_schedule
13+
14+
15+
@find_noise_schedule.register
16+
def _(name: str, *args, **kwargs):
17+
match name.lower():
18+
case "cosine":
19+
from .noise_schedules import CosineNoiseSchedule
20+
21+
return CosineNoiseSchedule()
22+
case "edm":
23+
from .noise_schedules import EDMNoiseSchedule
24+
25+
return EDMNoiseSchedule()
26+
case other:
27+
raise ValueError(f"Unsupported noise schedule name: '{other}'.")
28+
29+
30+
@find_noise_schedule.register
31+
def _(config: dict, *args, **kwargs):
32+
name = config.get("type", "").lower()
33+
params = {k: v for k, v in config.items() if k != "type"}
34+
match name:
35+
case "cosine":
36+
from .noise_schedules import CosineNoiseSchedule
37+
38+
return CosineNoiseSchedule(**params)
39+
case "edm":
40+
from .noise_schedules import EDMNoiseSchedule
41+
42+
return EDMNoiseSchedule(**params)
43+
case other:
44+
raise ValueError(f"Unsupported noise schedule config: '{other}'.")
45+
46+
47+
@find_noise_schedule.register
48+
def _(cls: type, *args, **kwargs):
49+
# Lazily import NoiseSchedule class and compare
50+
from .noise_schedules import NoiseSchedule
51+
52+
if issubclass(cls, NoiseSchedule):
53+
return cls(*args, **kwargs)
54+
raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}")
55+
56+
57+
@find_noise_schedule.register
58+
def _(schedule: type, *args, **kwargs):
59+
return schedule

bayesflow/experimental/noise_schedules.py renamed to bayesflow/experimental/diffusion_model/noise_schedules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from bayesflow.utils.serialization import deserialize, serializable
99

1010

11-
@serializable("bayesflow.experimental")
11+
# disable module check, use potential module after moving from experimental
12+
@serializable("bayesflow.networks", disable_module_check=True)
1213
class NoiseSchedule(ABC):
1314
r"""Noise schedule for diffusion models. We follow the notation from [1].
1415
@@ -227,7 +228,8 @@ def from_config(cls, config, custom_objects=None):
227228
return cls(**deserialize(config, custom_objects=custom_objects))
228229

229230

230-
@serializable("bayesflow.experimental")
231+
# disable module check, use potential module after moving from experimental
232+
@serializable("bayesflow.networks", disable_module_check=True)
231233
class EDMNoiseSchedule(NoiseSchedule):
232234
"""EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
233235
This should be used with the F-prediction type in the diffusion model.

bayesflow/utils/dispatch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@
66
from .find_inference_network import find_inference_network
77
from .find_summary_network import find_summary_network
88
from .find_distribution import find_distribution
9-
from .find_noise_schedule import find_noise_schedule
Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +0,0 @@
1-
from functools import singledispatch
2-
3-
4-
@singledispatch
5-
def find_noise_schedule(arg, *args, **kwargs):
6-
raise TypeError(f"Unknown noise schedule: {arg!r}")
7-
8-
9-
@find_noise_schedule.register
10-
def _(name: str, *args, **kwargs):
11-
match name.lower():
12-
case "cosine":
13-
from bayesflow.experimental.noise_schedules import CosineNoiseSchedule
14-
15-
return CosineNoiseSchedule()
16-
case "edm":
17-
from bayesflow.experimental.noise_schedules import EDMNoiseSchedule
18-
19-
return EDMNoiseSchedule()
20-
case other:
21-
raise ValueError(f"Unsupported noise schedule name: '{other}'.")
22-
23-
24-
@find_noise_schedule.register
25-
def _(config: dict, *args, **kwargs):
26-
name = config.get("type", "").lower()
27-
params = {k: v for k, v in config.items() if k != "type"}
28-
match name:
29-
case "cosine":
30-
from bayesflow.experimental.noise_schedules import CosineNoiseSchedule
31-
32-
return CosineNoiseSchedule(**params)
33-
case "edm":
34-
from bayesflow.experimental.noise_schedules import EDMNoiseSchedule
35-
36-
return EDMNoiseSchedule(**params)
37-
case other:
38-
raise ValueError(f"Unsupported noise schedule config: '{other}'.")
39-
40-
41-
@find_noise_schedule.register
42-
def _(cls: type, *args, **kwargs):
43-
# Lazily import NoiseSchedule class and compare
44-
from bayesflow.experimental.noise_schedules import NoiseSchedule
45-
46-
if issubclass(cls, NoiseSchedule):
47-
return cls(*args, **kwargs)
48-
raise TypeError(f"Expected subclass of NoiseSchedule, got {cls}")
49-
50-
51-
@find_noise_schedule.register
52-
def _(schedule: type, *args, **kwargs):
53-
return schedule

0 commit comments

Comments
 (0)