Skip to content

Commit b2ef755

Browse files
committed
fix integrate_kwargs
1 parent 0ff960f commit b2ef755

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def compositional_bridge(self, time: Tensor) -> Tensor:
617617
Bridge function value with same shape as time.
618618
619619
"""
620-
return ops.exp(-np.log(self.compositional_d0 / self.compositional_d1) * time)
620+
return ops.exp(-np.log(self.compositional_bridge_d0 / self.compositional_bridge_d1) * time)
621621

622622
def compositional_velocity(
623623
self,
@@ -813,8 +813,8 @@ def _inverse_compositional(
813813
)
814814
else:
815815
mini_batch_size = integrate_kwargs.pop("mini_batch_size", int(n_compositional * 0.1))
816-
self.compositional_d0 = float(integrate_kwargs.pop("compositional_d0", 1.0))
817-
self.compositional_d1 = float(integrate_kwargs.pop("compositional_d1", 1.0))
816+
self.compositional_bridge_d0 = float(integrate_kwargs.pop("compositional_bridge_d0", 1.0))
817+
self.compositional_bridge_d1 = float(integrate_kwargs.pop("compositional_bridge_d1", 1.0))
818818

819819
# x is sampled from a normal distribution, must be scaled with var 1/n_compositional
820820
scale_latent = n_compositional * self.compositional_bridge(ops.ones(1))
@@ -893,6 +893,7 @@ def score_fn(time, xz):
893893
**integrate_kwargs,
894894
)
895895
else:
896+
integrate_kwargs.pop("corrector_steps", None)
896897

897898
def deltas(time, xz):
898899
return {

0 commit comments

Comments
 (0)