Skip to content

Commit d2a36a8

Browse files
committed
better compositional_bridge
1 parent b2991d1 commit d2a36a8

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import keras
55
from keras import ops
6+
import numpy as np
67

78
from ..inference_network import InferenceNetwork
89
from bayesflow.types import Tensor, Shape
@@ -600,10 +601,10 @@ def compute_metrics(
600601
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
601602
return base_metrics | {"loss": loss}
602603

603-
@staticmethod
604-
def compositional_bridge(time: Tensor) -> Tensor:
604+
def compositional_bridge(self, time: Tensor) -> Tensor:
605605
"""
606-
Bridge function for compositional diffusion. In the simplest case, this is just 1.
606+
Bridge function for compositional diffusion. In the simplest case, this is just 1 if d0 == d1.
607+
Otherwise, it can be used to scale the compositional score over time.
607608
608609
Parameters
609610
----------
@@ -616,7 +617,7 @@ def compositional_bridge(time: Tensor) -> Tensor:
616617
Bridge function value with same shape as time.
617618
618619
"""
619-
return ops.ones_like(time)
620+
return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time)
620621

621622
def compositional_velocity(
622623
self,
@@ -812,6 +813,8 @@ def _inverse_compositional(
812813
)
813814
else:
814815
mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1))
816+
self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0))
817+
self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0))
815818

816819
# x is sampled from a normal distribution, must be scaled with var 1/n_compositional
817820
scale_latent = n_compositional * self.compositional_bridge(ops.ones(1))

0 commit comments

Comments
 (0)