Skip to content

Commit b2991d1

Browse files
committed
better standard values for compositional
1 parent 922040d commit b2991d1

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -799,21 +799,21 @@ def _inverse_compositional(
799799
"""
800800
Inverse pass for compositional diffusion sampling.
801801
"""
802-
integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0}
802+
n_compositional = ops.shape(conditions)[1]
803+
integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0, "corrector_steps": 1}
803804
integrate_kwargs = integrate_kwargs | self.integrate_kwargs
804805
integrate_kwargs = integrate_kwargs | kwargs
805-
mini_batch_size = integrate_kwargs.pop("mini_batch_size", None)
806-
807-
if mini_batch_size is not None:
808-
# if backend is jax, mini batching does not work
809-
if keras.backend.backend() == "jax":
806+
if keras.backend.backend() == "jax":
807+
mini_batch_size = integrate_kwargs.pop("mini_batch_size", None)
808+
if mini_batch_size is not None:
810809
raise ValueError(
811810
"Mini batching is not supported with JAX backend. Set mini_batch_size to None "
812811
"or use another backend."
813812
)
813+
else:
814+
mini_batch_size = integrate_kwargs.get("mini_batch_size", int(n_compositional * 0.1))
814815

815816
# x is sampled from a normal distribution, must be scaled with var 1/n_compositional
816-
n_compositional = ops.shape(conditions)[1]
817817
scale_latent = n_compositional * self.compositional_bridge(ops.ones(1))
818818
z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z)))
819819

0 commit comments

Comments
 (0)