@@ -41,10 +41,10 @@ class NoiseSchedule(ABC):
 
     def __init__(self, name: str, variance_type: str, weighting: str = None):
         self.name = name
-        self.variance_type = variance_type  # 'exploding' or 'preserving'
-        self._log_snr_min = -15  # should be set in the subclasses
-        self._log_snr_max = 15  # should be set in the subclasses
-        self.weighting = weighting
+        self._variance_type = variance_type  # 'exploding' or 'preserving'
+        self.log_snr_min = -15  # should be set in the subclasses
+        self.log_snr_max = 15  # should be set in the subclasses
+        self._weighting = weighting
 
     @abstractmethod
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
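For orientation: `get_log_snr` maps diffusion time t in [0, 1] to the log signal-to-noise ratio lambda(t) = log(alpha_t^2 / sigma_t^2), which decreases as noise is added. A minimal standalone sketch of such a mapping, using a hypothetical linear-in-lambda schedule in plain Python rather than the backend `ops` (illustrative only, not the schedule shipped here):

```python
# Hypothetical schedule, linear in lambda: interpolates between the
# log-SNR bounds; lambda falls monotonically as t goes from 0 to 1.
def get_log_snr(t: float, log_snr_min: float = -15.0, log_snr_max: float = 15.0) -> float:
    return log_snr_max + t * (log_snr_min - log_snr_max)

assert get_log_snr(0.0) == 15.0   # near-clean data: high SNR
assert get_log_snr(1.0) == -15.0  # near-pure noise: low SNR
```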
@@ -76,12 +76,12 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
         if x is None:  # return g^2 only
             return beta
-        if self.variance_type == "preserving":
+        if self._variance_type == "preserving":
             f = -0.5 * beta * x
-        elif self.variance_type == "exploding":
+        elif self._variance_type == "exploding":
             f = ops.zeros_like(beta)
         else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")
+            raise ValueError(f"Unknown variance type: {self._variance_type}")
         return f, beta
 
     def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
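The branching above implements the drift/diffusion split of the SDE dx = f(x, t) dt + g(t) dW with g^2 = beta(t): variance-preserving schedules use f = -0.5 * beta * x, variance-exploding ones have zero drift. A toy standalone check in plain Python (`vp_drift` and `ve_drift` are illustrative names, not part of the API):

```python
# Variance-preserving drift for a scalar state x, given g^2 = beta,
# mirroring the "preserving" branch: f(x, t) = -0.5 * beta(t) * x.
def vp_drift(beta: float, x: float) -> float:
    return -0.5 * beta * x

# Variance-exploding SDEs have no drift term ("exploding" branch).
def ve_drift(beta: float, x: float) -> float:
    return 0.0

assert vp_drift(beta=2.0, x=3.0) == -3.0
assert ve_drift(beta=2.0, x=3.0) == 0.0
```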
@@ -92,58 +92,58 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         sigma(t) = sqrt(sigmoid(-log_snr_t))
         For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
         """
-        if self.variance_type == "preserving":
+        if self._variance_type == "preserving":
             # variance preserving schedule
             alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
             sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
-        elif self.variance_type == "exploding":
+        elif self._variance_type == "exploding":
             # variance exploding schedule
             alpha_t = ops.ones_like(log_snr_t)
             sigma_t = ops.sqrt(ops.exp(-log_snr_t))
         else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")
+            raise ValueError(f"Unknown variance type: {self._variance_type}")
         return alpha_t, sigma_t
 
     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
         Generally, weighting functions should be defined for a noise prediction loss.
         """
-        if self.weighting is None:
+        if self._weighting is None:
             return ops.ones_like(log_snr_t)
-        elif self.weighting == "sigmoid":
+        elif self._weighting == "sigmoid":
             # sigmoid weighting based on Kingma et al. (2023)
             return ops.sigmoid(-log_snr_t + 2)
-        elif self.weighting == "likelihood_weighting":
+        elif self._weighting == "likelihood_weighting":
             # likelihood weighting based on Song et al. (2021)
             g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
             sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
             return g_squared / ops.square(sigma_t)
         else:
-            raise ValueError(f"Unknown weighting type: {self.weighting}")
+            raise ValueError(f"Unknown weighting type: {self._weighting}")
 
     def get_config(self):
-        return dict(name=self.name, variance_type=self.variance_type)
+        return dict(name=self.name, variance_type=self._variance_type)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))
 
     def validate(self):
         """Validate the noise schedule."""
-        if self._log_snr_min >= self._log_snr_max:
+        if self.log_snr_min >= self.log_snr_max:
             raise ValueError("min_log_snr must be less than max_log_snr.")
         for training in [True, False]:
             if not ops.isfinite(self.get_log_snr(0.0, training=training)):
                 raise ValueError("log_snr(0) must be finite.")
             if not ops.isfinite(self.get_log_snr(1.0, training=training)):
                 raise ValueError("log_snr(1) must be finite.")
-            if not ops.isfinite(self.get_t_from_log_snr(self._log_snr_max, training=training)):
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
                 raise ValueError("t(0) must be finite.")
-            if not ops.isfinite(self.get_t_from_log_snr(self._log_snr_min, training=training)):
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
                 raise ValueError("t(1) must be finite.")
-        if not ops.isfinite(self.derivative_log_snr(self._log_snr_max, training=False)):
+        if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
             raise ValueError("d/dt log_snr(0) must be finite.")
-        if not ops.isfinite(self.derivative_log_snr(self._log_snr_min, training=False)):
+        if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
             raise ValueError("d/dt log_snr(1) must be finite.")
 
 
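For the variance-preserving branch, alpha_t = sqrt(sigmoid(lambda)) and sigma_t = sqrt(sigmoid(-lambda)) imply alpha_t^2 + sigma_t^2 = 1 and alpha_t^2 / sigma_t^2 = exp(lambda) at every lambda. A quick numeric sanity check of those identities, standalone Python with no backend required:

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

# Variance-preserving alpha/sigma, mirroring the "preserving" branch above.
def vp_alpha_sigma(log_snr: float) -> tuple[float, float]:
    return math.sqrt(sigmoid(log_snr)), math.sqrt(sigmoid(-log_snr))

for lam in (-4.0, 0.0, 4.0):
    a, s = vp_alpha_sigma(lam)
    assert abs(a**2 + s**2 - 1.0) < 1e-12          # variance is preserved
    assert abs(a**2 / s**2 - math.exp(lam)) < 1e-9  # ratio recovers the SNR
```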
@@ -158,11 +158,11 @@ class LinearNoiseSchedule(NoiseSchedule):
 
     def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
         super().__init__(name="linear_noise_schedule", variance_type="preserving", weighting="likelihood_weighting")
-        self._log_snr_min = min_log_snr
-        self._log_snr_max = max_log_snr
+        self.log_snr_min = min_log_snr
+        self.log_snr_max = max_log_snr
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
 
     def _truncated_t(self, t: Tensor) -> Tensor:
         return self._t_min + (self._t_max - self._t_min) * t
@@ -194,7 +194,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         return -factor * dsnr_dt
 
     def get_config(self):
-        return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)
+        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
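`_truncated_t` affinely maps raw t in [0, 1] onto [t_min, t_max], where t_min = t(log_snr_max) and t_max = t(log_snr_min), so training never samples times whose log-SNR leaves the finite bounds. A standalone sketch of that mapping (plain floats in place of tensors; the bound values are illustrative):

```python
# Affine truncation used by the schedules: maps raw t in [0, 1]
# onto [t_min, t_max] so log_snr(t) stays within the finite range.
def truncated_t(t: float, t_min: float, t_max: float) -> float:
    return t_min + (t_max - t_min) * t

assert abs(truncated_t(0.0, t_min=0.01, t_max=0.99) - 0.01) < 1e-12
assert abs(truncated_t(1.0, t_min=0.01, t_max=0.99) - 0.99) < 1e-12
```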
@@ -214,12 +214,11 @@ def __init__(
     ):
         super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
         self._s_shift_cosine = s_shift_cosine
-        self._log_snr_min = min_log_snr
-        self._log_snr_max = max_log_snr
-        self._s_shift_cosine = s_shift_cosine
+        self.log_snr_min = min_log_snr
+        self.log_snr_max = max_log_snr
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
 
     def _truncated_t(self, t: Tensor) -> Tensor:
         return self._t_min + (self._t_max - self._t_min) * t
@@ -250,7 +249,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         return -factor * dsnr_dt
 
     def get_config(self):
-        return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max, s_shift_cosine=self._s_shift_cosine)
+        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, s_shift_cosine=self._s_shift_cosine)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
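The diff does not show `get_log_snr` for the cosine schedule. Assuming the shifted-cosine form from Kingma & Gao (2023), lambda(t) = -2 log tan(pi t / 2) + 2 log s, a sketch would be the following (an assumption about the formula, not code from this commit; `s_shift` plays the role of `s_shift_cosine`):

```python
import math

# Shifted cosine schedule in log-SNR form (assumed from Kingma & Gao, 2023).
# s_shift = 1 recovers the plain cosine schedule; any other shift moves the
# whole log-SNR curve up or down by 2 * log(s_shift).
def cosine_log_snr(t: float, s_shift: float = 1.0) -> float:
    return -2.0 * math.log(math.tan(math.pi * t / 2.0)) + 2.0 * math.log(s_shift)

# At the midpoint t = 0.5: tan(pi/4) = 1, so lambda = 2 * log(s_shift) = 0 here.
assert abs(cosine_log_snr(0.5)) < 1e-12
```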
@@ -278,12 +277,12 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
         self.rho = 7
 
         # convert EDM parameters to signal-to-noise ratio formulation
-        self._log_snr_min = -2 * ops.log(sigma_max)
-        self._log_snr_max = -2 * ops.log(sigma_min)
+        self.log_snr_min = -2 * ops.log(sigma_max)
+        self.log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
         # training bounds should be set to avoid numerical issues
-        self._log_snr_min_training = self._log_snr_min - 1  # one is never sampled during training
-        self._log_snr_max_training = self._log_snr_max + 1  # 0 is almost surely never sampled during training
+        self._log_snr_min_training = self.log_snr_min - 1  # one is never sampled during training
+        self._log_snr_max_training = self.log_snr_max + 1  # 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
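EDM parameterizes noise by sigma with alpha = 1, so snr = 1 / sigma^2 and lambda = -2 log sigma; sigma_max therefore yields log_snr_min and sigma_min yields log_snr_max, which is exactly the pair of converted bounds above. A standalone check (the 80.0 below is an arbitrary illustrative sigma_max, not a class default):

```python
import math

# EDM noise scale sigma <-> log-SNR: with alpha = 1, snr = 1 / sigma^2,
# so lambda = -2 * log(sigma); large sigmas give the *minimum* log-SNR.
def log_snr_from_sigma(sigma: float) -> float:
    return -2.0 * math.log(sigma)

assert log_snr_from_sigma(1.0) == 0.0  # sigma = 1 is the SNR = 1 point
assert log_snr_from_sigma(80.0) < log_snr_from_sigma(1e-4)  # sigma_max -> log_snr_min
```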
@@ -537,9 +536,9 @@ def velocity(
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
 
         if conditions is None:
-            xtc = ops.concatenate([xz, log_snr_t], axis=-1)
+            xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t)], axis=-1)
         else:
-            xtc = ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
@@ -587,6 +586,16 @@ def f(x):
 
         return v, ops.expand_dims(trace, axis=-1)
 
+    def _transform_log_snr(self, log_snr: Tensor) -> Tensor:
+        """Transform the log_snr to the range [-1, 1] for the diffusion process."""
+        # affine map from [log_snr_min, log_snr_max] onto [-1, 1]
+        return (
+            2
+            * (log_snr - self.noise_schedule.log_snr_min)
+            / (self.noise_schedule.log_snr_max - self.noise_schedule.log_snr_min)
+            - 1
+        )
+
     def _forward(
         self,
         x: Tensor,
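Both `velocity` (previous hunk) and `compute_metrics` (next hunk) now condition the subnet on this rescaled value instead of raw lambda, so the time embedding stays in [-1, 1] regardless of the schedule's log-SNR bounds. A standalone check of the affine map, plain floats in place of tensors:

```python
# Rescales lambda from [log_snr_min, log_snr_max] to [-1, 1], mirroring
# _transform_log_snr above (the default bounds here are illustrative).
def transform_log_snr(log_snr: float, lo: float = -15.0, hi: float = 15.0) -> float:
    return 2.0 * (log_snr - lo) / (hi - lo) - 1.0

assert transform_log_snr(-15.0) == -1.0  # lower bound maps to -1
assert transform_log_snr(15.0) == 1.0    # upper bound maps to +1
assert transform_log_snr(0.0) == 0.0     # midpoint maps to 0
```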
@@ -749,9 +758,9 @@ def compute_metrics(
 
         # calculate output of the network
         if conditions is None:
-            xtc = ops.concatenate([diffused_x, log_snr_t], axis=-1)
+            xtc = ops.concatenate([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
         else:
-            xtc = ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(