@@ -593,6 +593,64 @@ def compositional_velocity(
593593 training : bool, optional
594594 Whether in training mode
595595
596+ Returns
597+ -------
598+ Tensor
599+ Compositional velocity of same shape as input xz
600+ """
601+ # Calculate standard noise schedule components
602+ log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t = time , training = training ), xz )
603+ log_snr_t = ops .broadcast_to (log_snr_t , ops .shape (xz )[:- 1 ] + (1 ,))
604+
605+ compositional_score = self .compositional_score (
606+ xz = xz ,
607+ time = time ,
608+ conditions = conditions ,
609+ compute_prior_score = compute_prior_score ,
610+ mini_batch_size = mini_batch_size ,
611+ training = training ,
612+ )
613+
614+ # Compute velocity using standard drift-diffusion formulation
615+ f , g_squared = self .noise_schedule .get_drift_diffusion (log_snr_t = log_snr_t , x = xz , training = training )
616+
617+ if stochastic_solver :
618+ # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW
619+ velocity = f - g_squared * compositional_score
620+ else :
621+ # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt
622+ velocity = f - 0.5 * g_squared * compositional_score
623+
624+ return velocity
625+
626+ def compositional_score (
627+ self ,
628+ xz : Tensor ,
629+ time : float | Tensor ,
630+ conditions : Tensor ,
631+ compute_prior_score : Callable [[Tensor ], Tensor ],
632+ mini_batch_size : int | None = None ,
633+ training : bool = False ,
634+ ) -> Tensor :
635+ """
636+ Computes the compositional score for multiple datasets using the formula:
637+ s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ)
638+
639+ Parameters
640+ ----------
641+ xz : Tensor
642+ The current state of the latent variable, shape (n_datasets, n_compositional, ...)
643+ time : float or Tensor
644+ Time step for the diffusion process
645+ conditions : Tensor
646+ Conditional inputs with compositional structure (n_datasets, n_compositional, ...)
647+ compute_prior_score: Callable
648+ Function to compute the prior score ∇_θ log p(θ).
649+ mini_batch_size : int or None
650+ Mini batch size for computing individual scores. If None, use all conditions.
651+ training : bool, optional
652+ Whether in training mode
653+
596654 Returns
597655 -------
598656 Tensor
@@ -631,18 +689,7 @@ def compositional_velocity(
631689 # Combined score
632690 time_tensor = ops .cast (time , dtype = ops .dtype (xz ))
633691 compositional_score = self .compositional_bridge (time_tensor ) * (weighted_prior_score + summed_individual_scores )
634-
635- # Compute velocity using standard drift-diffusion formulation
636- f , g_squared = self .noise_schedule .get_drift_diffusion (log_snr_t = log_snr_t , x = xz , training = training )
637-
638- if stochastic_solver :
639- # SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW
640- velocity = f - g_squared * compositional_score
641- else :
642- # ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt
643- velocity = f - 0.5 * g_squared * compositional_score
644-
645- return velocity
692+ return compositional_score
646693
647694 def _compute_individual_scores (
648695 self ,
0 commit comments