@@ -528,9 +528,9 @@ def velocity(
528528 self ,
529529 xz : Tensor ,
530530 time : float | Tensor ,
531+ stochastic_solver : bool ,
531532 conditions : Tensor = None ,
532533 training : bool = False ,
533- stochastic_solver : bool = False ,
534534 clip_x : bool = False ,
535535 ) -> Tensor :
536536 # calculate the current noise level and transform into correct shape
@@ -583,7 +583,7 @@ def _velocity_trace(
583583 training : bool = False ,
584584 ) -> (Tensor , Tensor ):
585585 def f (x ):
586- return self .velocity (x , time = time , conditions = conditions , training = training )
586+ return self .velocity (x , time = time , stochastic_solver = False , conditions = conditions , training = training )
587587
588588 v , trace = jacobian_trace (f , xz , max_steps = max_steps , seed = self .seed_generator , return_output = True )
589589
@@ -630,7 +630,9 @@ def deltas(time, xz):
630630 return z , log_density
631631
632632 def deltas (time , xz ):
633- return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
633+ return {
634+ "xz" : self .velocity (xz , time = time , stochastic_solver = False , conditions = conditions , training = training )
635+ }
634636
635637 state = {"xz" : x }
636638 state = integrate (
@@ -676,12 +678,14 @@ def deltas(time, xz):
676678
677679 return x , log_density
678680
679- def deltas (time , xz ):
680- return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
681-
682681 state = {"xz" : z }
683682 if integrate_kwargs ["method" ] == "euler_maruyama" :
684683
684+ def deltas (time , xz ):
685+ return {
686+ "xz" : self .velocity (xz , time = time , stochastic_solver = True , conditions = conditions , training = training )
687+ }
688+
685689 def diffusion (time , xz ):
686690 return {"xz" : self .compute_diffusion_term (xz , time = time , training = training )}
687691
@@ -692,6 +696,14 @@ def diffusion(time, xz):
692696 ** integrate_kwargs ,
693697 )
694698 else :
699+
700+ def deltas (time , xz ):
701+ return {
702+ "xz" : self .velocity (
703+ xz , time = time , stochastic_solver = False , conditions = conditions , training = training
704+ )
705+ }
706+
695707 state = integrate (
696708 deltas ,
697709 state ,
@@ -709,6 +721,7 @@ def compute_metrics(
709721 stage : str = "training" ,
710722 ) -> dict [str , Tensor ]:
711723 training = stage == "training"
724+ noise_schedule_training_stage = stage == "training" or stage == "validation"
712725 if not self .built :
713726 xz_shape = keras .ops .shape (x )
714727 conditions_shape = None if conditions is None else keras .ops .shape (conditions )
@@ -723,8 +736,10 @@ def compute_metrics(
723736 # t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
724737
725738 # calculate the noise level
726- log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t , training = training ), x )
727- alpha_t , sigma_t = self .noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t , training = training )
739+ log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t , training = noise_schedule_training_stage ), x )
740+ alpha_t , sigma_t = self .noise_schedule .get_alpha_sigma (
741+ log_snr_t = log_snr_t , training = noise_schedule_training_stage
742+ )
728743
729744 # generate noise vector
730745 eps_t = keras .random .normal (ops .shape (x ), dtype = ops .dtype (x ), seed = self .seed_generator )
0 commit comments