Skip to content

Commit 0781032

Browse files
committed
fix serializable
1 parent b5d6f0f commit 0781032

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from bayesflow.utils.serialization import serialize, deserialize, serializable
2222

2323

24-
@serializable
24+
@serializable("bayesflow.experimental")
2525
class DiffusionModel(InferenceNetwork):
2626
"""Diffusion Model as described in this overview paper [1].
2727
@@ -49,7 +49,7 @@ def __init__(
4949
*,
5050
subnet: str | type = "mlp",
5151
integrate_kwargs: dict[str, any] = None,
52-
noise_schedule: Literal["edm", "cosine"] | type = "edm",
52+
noise_schedule: Literal["edm", "cosine"] | dict | type = "edm",
5353
prediction_type: Literal["velocity", "noise", "F"] = "F",
5454
**kwargs,
5555
):
@@ -69,8 +69,10 @@ def __init__(
6969
callable network. Default is "mlp".
7070
integrate_kwargs : dict[str, any], optional
7171
Additional keyword arguments for the integration process. Default is None.
72-
noise_schedule : Literal['edm', 'cosine'] or type, optional
72+
noise_schedule : Literal['edm', 'cosine'], dict or type, optional
7373
The noise schedule used for the diffusion process. Can be "cosine" or "edm" or a custom noise schedule.
74+
You can also pass a dictionary with the configuration for the noise schedule, e.g.,
75+
{'type': cosine, 's_shift_cosine': 1.0}
7476
Default is "edm".
7577
prediction_type: Literal['velocity', 'noise', 'F'], optional
7678
The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).

bayesflow/experimental/noise_schedules.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bayesflow.utils.serialization import deserialize, serializable
99

1010

11-
@serializable
11+
@serializable("bayesflow.experimental")
1212
class NoiseSchedule(ABC):
1313
r"""Noise schedule for diffusion models. We follow the notation from [1].
1414
@@ -155,7 +155,7 @@ def validate(self):
155155
raise ValueError("dt/t log_snr(1) must be finite.")
156156

157157

158-
@serializable
158+
@serializable("bayesflow.experimental")
159159
class CosineNoiseSchedule(NoiseSchedule):
160160
"""Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
161161
For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
@@ -209,7 +209,7 @@ def from_config(cls, config, custom_objects=None):
209209
return cls(**deserialize(config, custom_objects=custom_objects))
210210

211211

212-
@serializable
212+
@serializable("bayesflow.experimental")
213213
class EDMNoiseSchedule(NoiseSchedule):
214214
"""EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
215215
This should be used with the F-prediction type in the diffusion model.

0 commit comments

Comments
 (0)