Skip to content

Commit 1ac9bff

Browse files
committed
reorganize
1 parent caa2d67 commit 1ac9bff

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,64 @@ def compositional_velocity(
593593
training : bool, optional
594594
Whether in training mode
595595
596+
Returns
597+
-------
598+
Tensor
599+
Compositional velocity of same shape as input xz
600+
"""
601+
# Calculate standard noise schedule components
602+
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
603+
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
604+
605+
compositional_score = self.compositional_score(
606+
xz=xz,
607+
time=time,
608+
conditions=conditions,
609+
compute_prior_score=compute_prior_score,
610+
mini_batch_size=mini_batch_size,
611+
training=training,
612+
)
613+
614+
# Compute velocity using standard drift-diffusion formulation
615+
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
616+
617+
if stochastic_solver:
618+
# SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW
619+
velocity = f - g_squared * compositional_score
620+
else:
621+
# ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt
622+
velocity = f - 0.5 * g_squared * compositional_score
623+
624+
return velocity
625+
626+
def compositional_score(
627+
self,
628+
xz: Tensor,
629+
time: float | Tensor,
630+
conditions: Tensor,
631+
compute_prior_score: Callable[[Tensor], Tensor],
632+
mini_batch_size: int | None = None,
633+
training: bool = False,
634+
) -> Tensor:
635+
"""
636+
Computes the compositional score for multiple datasets using the formula:
637+
s_ψ(θ,t,Y) = (1-n)(1-t) ∇_θ log p(θ) + Σᵢ₌₁ⁿ s_ψ(θ,t,yᵢ)
638+
639+
Parameters
640+
----------
641+
xz : Tensor
642+
The current state of the latent variable, shape (n_datasets, n_compositional, ...)
643+
time : float or Tensor
644+
Time step for the diffusion process
645+
conditions : Tensor
646+
Conditional inputs with compositional structure (n_datasets, n_compositional, ...)
647+
compute_prior_score: Callable
648+
Function to compute the prior score ∇_θ log p(θ).
649+
mini_batch_size : int or None
650+
Mini batch size for computing individual scores. If None, use all conditions.
651+
training : bool, optional
652+
Whether in training mode
653+
596654
Returns
597655
-------
598656
Tensor
@@ -631,18 +689,7 @@ def compositional_velocity(
631689
# Combined score
632690
time_tensor = ops.cast(time, dtype=ops.dtype(xz))
633691
compositional_score = self.compositional_bridge(time_tensor) * (weighted_prior_score + summed_individual_scores)
634-
635-
# Compute velocity using standard drift-diffusion formulation
636-
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
637-
638-
if stochastic_solver:
639-
# SDE: dz = [f(z,t) - g(t)² * score(z,t)] dt + g(t) dW
640-
velocity = f - g_squared * compositional_score
641-
else:
642-
# ODE: dz = [f(z,t) - 0.5 * g(t)² * score(z,t)] dt
643-
velocity = f - 0.5 * g_squared * compositional_score
644-
645-
return velocity
692+
return compositional_score
646693

647694
def _compute_individual_scores(
648695
self,

0 commit comments

Comments
 (0)