@@ -374,7 +374,7 @@ class DiffusionModel(InferenceNetwork):
374374 }
375375
376376 INTEGRATE_DEFAULT_CONFIG = {
377- "method" : "euler" ,
377+ "method" : "euler" , # or euler_maruyama
378378 "steps" : 100 ,
379379 }
380380
@@ -530,6 +530,7 @@ def velocity(
530530 time : float | Tensor ,
531531 conditions : Tensor = None ,
532532 training : bool = False ,
533+ stochastic_solver : bool = False ,
533534 clip_x : bool = False ,
534535 ) -> Tensor :
535536 # calculate the current noise level and transform into correct shape
@@ -549,44 +550,28 @@ def velocity(
549550 # convert x to score
550551 score = (alpha_t * x_pred - xz ) / ops .square (sigma_t )
551552
552- # compute velocity for the ODE depending on the noise schedule
553+ # compute velocity f, g of the SDE or ODE
553554 f , g_squared = self .noise_schedule .get_drift_diffusion (log_snr_t = log_snr_t , x = xz )
554- # out = f - 0.5 * g_squared * score
555- out = f - g_squared * score
556555
557- # todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
556+ if stochastic_solver :
557+ # for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
558+ out = f - g_squared * score
559+ else :
560+ # for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt
561+ out = f - 0.5 * g_squared * score
562+
558563 return out
559564
560- def velocity2 (
565+ def compute_diffusion_term (
561566 self ,
562567 xz : Tensor ,
563568 time : float | Tensor ,
564- conditions : Tensor = None ,
565569 training : bool = False ,
566- clip_x : bool = False ,
567570 ) -> Tensor :
568571 # calculate the current noise level and transform into correct shape
569572 log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t = time , training = training ), xz )
570573 log_snr_t = keras .ops .broadcast_to (log_snr_t , keras .ops .shape (xz )[:- 1 ] + (1 ,))
571- # alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
572-
573- # if conditions is None:
574- # xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
575- # else:
576- # xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
577- # pred = self.output_projector(self.subnet(xtc, training=training), training=training)
578-
579- # x_pred = self.convert_prediction_to_x(
580- # pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=clip_x
581- # )
582- # convert x to score
583- # score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
584-
585- # compute velocity for the ODE depending on the noise schedule
586- f , g_squared = self .noise_schedule .get_drift_diffusion (log_snr_t = log_snr_t , x = xz )
587- # out = f - 0.5 * g_squared * score
588- # out = f - g_squared * score
589-
574+ g_squared = self .noise_schedule .get_drift_diffusion (log_snr_t = log_snr_t )
590575 return ops .sqrt (g_squared )
591576
592577 def _velocity_trace (
@@ -620,6 +605,9 @@ def _forward(
620605 | self .integrate_kwargs
621606 | kwargs
622607 )
608+ if integrate_kwargs ["method" ] == "euler_maruyama" :
609+ raise ValueError ("Stoachastic methods are not supported for forward integration." )
610+
623611 if density :
624612
625613 def deltas (time , xz ):
@@ -670,6 +658,8 @@ def _inverse(
670658 | kwargs
671659 )
672660 if density :
661+ if integrate_kwargs ["method" ] == "euler_maruyama" :
662+ raise ValueError ("Stoachastic methods are not supported for density computation." )
673663
674664 def deltas (time , xz ):
675665 v , trace = self ._velocity_trace (xz , time = time , conditions = conditions , training = training )
@@ -689,21 +679,24 @@ def deltas(time, xz):
689679 def deltas (time , xz ):
690680 return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
691681
692- def diffusion (time , xz ):
693- return {"xz" : self .velocity2 (xz , time = time , conditions = conditions , training = training )}
694-
695682 state = {"xz" : z }
696- # state = integrate(
697- # deltas,
698- # state,
699- # **integrate_kwargs,
700- # )
701- state = integrate_stochastic (
702- deltas ,
703- diffusion ,
704- state ,
705- ** integrate_kwargs ,
706- )
683+ if integrate_kwargs ["method" ] == "euler_maruyama" :
684+
685+ def diffusion (time , xz ):
686+ return {"xz" : self .compute_diffusion_term (xz , time = time , training = training )}
687+
688+ state = integrate_stochastic (
689+ deltas ,
690+ diffusion ,
691+ state ,
692+ ** integrate_kwargs ,
693+ )
694+ else :
695+ state = integrate (
696+ deltas ,
697+ state ,
698+ ** integrate_kwargs ,
699+ )
707700
708701 x = state ["xz" ]
709702 return x
0 commit comments