@@ -105,7 +105,7 @@ def __init__(
105105 self ._prediction_type = prediction_type
106106 self ._loss_type = loss_type
107107
108- self . schedule_kwargs = schedule_kwargs or {}
108+ schedule_kwargs = schedule_kwargs or {}
109109 self .noise_schedule = find_noise_schedule (noise_schedule , ** self .schedule_kwargs )
110110 self .noise_schedule .validate ()
111111
@@ -148,7 +148,6 @@ def get_config(self):
148148 "noise_schedule" : self .noise_schedule ,
149149 "prediction_type" : self ._prediction_type ,
150150 "loss_type" : self ._loss_type ,
151- "schedule_kwargs" : self .schedule_kwargs ,
152151 "integrate_kwargs" : self .integrate_kwargs ,
153152 }
154153 return base_config | serialize (config )
@@ -194,8 +193,9 @@ def convert_prediction_to_x(
194193 return x1 * z + x2 * pred
195194 elif self ._prediction_type == "x" :
196195 return pred
197- else :
196+ elif self . _prediction_type == "score" :
198197 return (z + sigma_t ** 2 * pred ) / alpha_t
198+ raise ValueError (f"Unknown prediction type { self ._prediction_type } ." )
199199
200200 def velocity (
201201 self ,
@@ -320,12 +320,9 @@ def _forward(
320320 training : bool = False ,
321321 ** kwargs ,
322322 ) -> Tensor | tuple [Tensor , Tensor ]:
323- integrate_kwargs = {
324- ** self .integrate_kwargs ,
325- "start_time" : kwargs .pop ("start_time" , 0.0 ),
326- "stop_time" : kwargs .pop ("stop_time" , 1.0 ),
327- ** kwargs ,
328- }
323+ integrate_kwargs = {"start_time" : 0.0 , "stop_time" : 1.0 }
324+ integrate_kwargs = integrate_kwargs | self .integrate_kwargs
325+ integrate_kwargs = integrate_kwargs | kwargs
329326
330327 if integrate_kwargs ["method" ] == "euler_maruyama" :
331328 raise ValueError ("Stochastic methods are not supported for forward integration." )
@@ -373,12 +370,9 @@ def _inverse(
373370 training : bool = False ,
374371 ** kwargs ,
375372 ) -> Tensor | tuple [Tensor , Tensor ]:
376- integrate_kwargs = {
377- ** self .integrate_kwargs ,
378- "start_time" : kwargs .pop ("start_time" , 1.0 ),
379- "stop_time" : kwargs .pop ("stop_time" , 0.0 ),
380- ** kwargs ,
381- }
373+ integrate_kwargs = {"start_time" : 1.0 , "stop_time" : 0.0 }
374+ integrate_kwargs = integrate_kwargs | self .integrate_kwargs
375+ integrate_kwargs = integrate_kwargs | kwargs
382376 if density :
383377 if integrate_kwargs ["method" ] == "euler_maruyama" :
384378 raise ValueError ("Stochastic methods are not supported for density computation." )
0 commit comments