11import math
2+ from typing import Literal
23
34from keras import ops
45
@@ -19,7 +20,11 @@ class EDMNoiseSchedule(NoiseSchedule):
1920 """
2021
2122 def __init__ (
22- self , sigma_data : float = 1.0 , sigma_min : float = 1e-4 , sigma_max : float = 80.0 , variance_type = "preserving"
23+ self ,
24+ sigma_data : float = 1.0 ,
25+ sigma_min : float = 1e-4 ,
26+ sigma_max : float = 80.0 ,
27+ variance_type : Literal ["preserving" , "exploding" ] = "preserving" ,
2328 ):
2429 """
2530 Initialize the EDM noise schedule.
@@ -33,9 +38,8 @@ def __init__(
3338 The minimum noise level. Only relevant for sampling. Default is 1e-4.
3439 sigma_max : float, optional
3540 The maximum noise level. Only relevant for sampling. Default is 80.0.
36- variance_type : str, optional
37- The type of variance to use. One of "preserving", or "exploding". Default is "preserving". Original EDM
38- paper uses "exploding".
41+ variance_type : Literal["preserving", "exploding"], optional
42+ The type of variance to use. Default is "preserving". Original EDM paper uses "exploding".
3943 """
4044 super ().__init__ (name = "edm_noise_schedule" , variance_type = variance_type )
4145 self .sigma_data = sigma_data
0 commit comments