Skip to content

Commit 3ff135d

Browse files
committed
allow exploding variance type in EDM schedule
1 parent 4e47cc4 commit 3ff135d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

bayesflow/networks/diffusion_model/schedules/edm_noise_schedule.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class EDMNoiseSchedule(NoiseSchedule):
1818
generative models. Advances in Neural Information Processing Systems, 35, 26565-26577.
1919
"""
2020

21-
def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0):
21+
def __init__(
22+
self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0, variance_type="preserving"
23+
):
2224
"""
2325
Initialize the EDM noise schedule.
2426
@@ -31,8 +33,11 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
3133
The minimum noise level. Only relevant for sampling. Default is 1e-4.
3234
sigma_max : float, optional
3335
The maximum noise level. Only relevant for sampling. Default is 80.0.
36+
variance_type : str, optional
37+
The type of variance to use. One of "preserving", or "exploding". Default is "preserving". Original EDM
38+
paper uses "exploding".
3439
"""
35-
super().__init__(name="edm_noise_schedule", variance_type="preserving")
40+
super().__init__(name="edm_noise_schedule", variance_type=variance_type)
3641
self.sigma_data = sigma_data
3742
# training settings
3843
self.p_mean = -1.2

0 commit comments

Comments
 (0)