33
44import keras
55from keras import ops
6+ import numpy as np
67
78from ..inference_network import InferenceNetwork
89from bayesflow .types import Tensor , Shape
@@ -600,10 +601,10 @@ def compute_metrics(
600601 base_metrics = super ().compute_metrics (x , conditions = conditions , sample_weight = sample_weight , stage = stage )
601602 return base_metrics | {"loss" : loss }
602603
603- @staticmethod
604- def compositional_bridge (time : Tensor ) -> Tensor :
604+ def compositional_bridge (self , time : Tensor ) -> Tensor :
605605 """
606- Bridge function for compositional diffusion. In the simplest case, this is just 1.
606+ Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1.
607+ Otherwise, it can be used to scale the compositional score over time.
607608
608609 Parameters
609610 ----------
@@ -616,7 +617,7 @@ def compositional_bridge(time: Tensor) -> Tensor:
616617 Bridge function value with same shape as time.
617618
618619 """
619- return ops .ones_like ( time )
620+ return ops .exp ( - np . log ( self . compositional_d0 / self . compositional_d1 ) * time )
620621
621622 def compositional_velocity (
622623 self ,
@@ -812,6 +813,8 @@ def _inverse_compositional(
812813 )
813814 else :
814815 mini_batch_size = integrate_kwargs .get ("mini_batch_size" , int (n_compositional * 0.1 ))
816+ self .compositional_d0 = float (integrate_kwargs .pop ("compositional_d0" , 1.0 ))
817+ self .compositional_d1 = float (integrate_kwargs .pop ("compositional_d1" , 1.0 ))
815818
816819 # x is sampled from a normal distribution, must be scaled with var 1/n_compositional
817820 scale_latent = n_compositional * self .compositional_bridge (ops .ones (1 ))
0 commit comments