@@ -246,6 +246,55 @@ def _apply_subnet(
         else:
             return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)

+    def score(
+        self,
+        xz: Tensor,
+        time: float | Tensor = None,
+        log_snr_t: Tensor = None,
+        conditions: Tensor = None,
+        training: bool = False,
+    ) -> Tensor:
+        """
+        Computes the score of the target or latent variable `xz`.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The current state of the latent variable `z`, typically of shape (..., D),
+            where D is the dimensionality of the latent space.
+        time : float or Tensor
+            Scalar or tensor representing the time (or noise level) at which the score
+            should be computed. Will be broadcast to `xz`. If None, `log_snr_t` must be provided.
+        log_snr_t : Tensor
+            The log signal-to-noise ratio at time `t`. If None, `time` must be provided.
+        conditions : Tensor, optional
+            Conditional inputs to the network, such as conditioning variables
+            or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
+        training : bool, optional
+            Whether the model is in training mode. Affects behavior of dropout, batch norm,
+            or other stochastic layers. Default is False.
+
+        Returns
+        -------
+        Tensor
+            The score tensor of the same shape as `xz`, i.e. the gradient of the log
+            marginal density of `xz` at the given noise level.
+        """
+        if log_snr_t is None:
+            log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
+            log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
+        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
+
+        subnet_out = self._apply_subnet(
+            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
+
+        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
+
+        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
+        return score
+
     def velocity(
         self,
         xz: Tensor,
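The added score() method recovers the marginal score from the network's x-prediction via (alpha_t * x_pred - xz) / sigma_t**2. As a sanity check (not part of this PR), that identity can be verified against the analytic score of a standard normal target, where the ideal x-prediction (the posterior mean) has a closed form:

import numpy as np

# Illustrative check of the x-prediction-to-score identity used above, for a
# standard normal target x ~ N(0, I) perturbed as z = alpha * x + sigma * eps.
# Then E[x | z] = alpha * z / (alpha**2 + sigma**2), and the true marginal score
# is -z / (alpha**2 + sigma**2); the formula from the diff reproduces it exactly.
rng = np.random.default_rng(0)
alpha, sigma = 0.8, 0.6
z = rng.standard_normal((4, 3))

x_pred = alpha * z / (alpha**2 + sigma**2)       # ideal x-prediction (posterior mean)
score = (alpha * x_pred - z) / sigma**2          # identity used in score()
true_score = -z / (alpha**2 + sigma**2)          # analytic score of N(0, (alpha^2 + sigma^2) I)

assert np.allclose(score, true_score)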
@@ -282,19 +331,10 @@ def velocity(
         The velocity tensor of the same shape as `xz`, representing the right-hand
         side of the SDE or ODE at the given `time`.
         """
-        # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
-        alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
-
-        subnet_out = self._apply_subnet(
-            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
-        )
-        pred = self.output_projector(subnet_out, training=training)

-        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
-
-        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
+        score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training)

         # compute velocity f, g of the SDE or ODE
         f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
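Downstream, velocity() combines this score with the forward drift f and squared diffusion g_squared returned by the noise schedule. Those lines are outside this hunk, but the standard reverse-time relations (Song et al.) look roughly like the sketch below; sign and time conventions in the actual implementation may differ:

# Sketch only: how a score typically enters the drift. The real velocity()
# implementation (not shown in this hunk) may use different conventions.
def drift_from_score(f, g_squared, score, stochastic_solver: bool):
    if stochastic_solver:
        # reverse-time SDE drift: f(x, t) - g(t)^2 * score
        return f - g_squared * score
    # probability-flow ODE drift: f(x, t) - 0.5 * g(t)^2 * score
    return f - 0.5 * g_squared * score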
@@ -452,9 +492,24 @@ def deltas(time, xz):
             def diffusion(time, xz):
                 return {"xz": self.diffusion_term(xz, time=time, training=training)}

+            score_fn = None
+            if "corrector_steps" in integrate_kwargs:
+                if integrate_kwargs["corrector_steps"] > 0:
+
+                    def score_fn(time, xz):
+                        return {
+                            "xz": self.score(
+                                xz,
+                                time=time,
+                                conditions=conditions,
+                                training=training,
+                            )
+                        }
+
             state = integrate_stochastic(
                 drift_fn=deltas,
                 diffusion_fn=diffusion,
+                score_fn=score_fn,
                 state=state,
                 seed=self.seed_generator,
                 **integrate_kwargs,
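The score_fn hook is only constructed when corrector_steps is requested, which points to a predictor-corrector scheme inside integrate_stochastic. A typical Langevin corrector step driven by such a score function looks roughly like the hypothetical helper below; it is not the library's actual update rule, and the step-size heuristic may differ:

import numpy as np

# Hypothetical sketch of one Langevin corrector step of the kind score_fn
# enables (after Song et al.'s predictor-corrector samplers).
def langevin_corrector_step(xz, score, rng, snr=0.1):
    noise = rng.standard_normal(xz.shape)
    grad_norm = np.mean(np.linalg.norm(score.reshape(score.shape[0], -1), axis=-1))
    noise_norm = np.mean(np.linalg.norm(noise.reshape(noise.shape[0], -1), axis=-1))
    step_size = 2.0 * (snr * noise_norm / grad_norm) ** 2
    return xz + step_size * score + np.sqrt(2.0 * step_size) * noise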
@@ -836,11 +891,11 @@ def deltas(time, xz):
             def diffusion(time, xz):
                 return {"xz": self.diffusion_term(xz, time=time, training=training)}

-            scores = None
+            score_fn = None
             if "corrector_steps" in integrate_kwargs:
                 if integrate_kwargs["corrector_steps"] > 0:

-                    def scores(time, xz):
+                    def score_fn(time, xz):
                         return {
                             "xz": self.compositional_score(
                                 xz,
@@ -855,7 +910,7 @@ def scores(time, xz):
             state = integrate_stochastic(
                 drift_fn=deltas,
                 diffusion_fn=diffusion,
-                score_fn=scores,
+                score_fn=score_fn,
                 state=state,
                 seed=self.seed_generator,
                 **integrate_kwargs,
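With both sampling paths (standard and compositional) now routed through the same score_fn interface, the new score() method can also be called directly. Its signature admits two call patterns, shown here with hypothetical placeholder objects (model, xz, conditions, log_snr_t are not defined in this diff):

# Hypothetical usage of the new score() method.
score_from_time = model.score(xz, time=0.5, conditions=conditions)            # log SNR derived internally from time
score_from_snr = model.score(xz, log_snr_t=log_snr_t, conditions=conditions)  # precomputed log SNR, as velocity() now passes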