Skip to content

Commit 55c18e2

Browse files
committed
fix type hint
1 parent 2aa0c02 commit 55c18e2

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

bayesflow/networks/diffusion_model/schedules/edm_noise_schedule.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from typing import Literal
23

34
from keras import ops
45

@@ -19,7 +20,11 @@ class EDMNoiseSchedule(NoiseSchedule):
1920
"""
2021

2122
def __init__(
22-
self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0, variance_type="preserving"
23+
self,
24+
sigma_data: float = 1.0,
25+
sigma_min: float = 1e-4,
26+
sigma_max: float = 80.0,
27+
variance_type: Literal["preserving", "exploding"] = "preserving",
2328
):
2429
"""
2530
Initialize the EDM noise schedule.
@@ -33,9 +38,8 @@ def __init__(
3338
The minimum noise level. Only relevant for sampling. Default is 1e-4.
3439
sigma_max : float, optional
3540
The maximum noise level. Only relevant for sampling. Default is 80.0.
36-
variance_type : str, optional
37-
The type of variance to use. One of "preserving", or "exploding". Default is "preserving". Original EDM
38-
paper uses "exploding".
41+
variance_type : Literal["preserving", "exploding"], optional
42+
The type of variance to use. Default is "preserving". Original EDM paper uses "exploding".
3943
"""
4044
super().__init__(name="edm_noise_schedule", variance_type=variance_type)
4145
self.sigma_data = sigma_data

0 commit comments

Comments
 (0)