@@ -799,21 +799,21 @@ def _inverse_compositional(
799799 """
800800 Inverse pass for compositional diffusion sampling.
801801 """
802- integrate_kwargs = {"start_time" : 1.0 , "stop_time" : 0.0 }
802+ n_compositional = ops .shape (conditions )[1 ]
803+ integrate_kwargs = {"start_time" : 1.0 , "stop_time" : 0.0 , "corrector_steps" : 1 }
803804 integrate_kwargs = integrate_kwargs | self .integrate_kwargs
804805 integrate_kwargs = integrate_kwargs | kwargs
805- mini_batch_size = integrate_kwargs .pop ("mini_batch_size" , None )
806-
807- if mini_batch_size is not None :
808- # if backend is jax, mini batching does not work
809- if keras .backend .backend () == "jax" :
806+ if keras .backend .backend () == "jax" :
807+ mini_batch_size = integrate_kwargs .pop ("mini_batch_size" , None )
808+ if mini_batch_size is not None :
810809 raise ValueError (
811810 "Mini batching is not supported with JAX backend. Set mini_batch_size to None "
812811 "or use another backend."
813812 )
813+ else :
814+ mini_batch_size = integrate_kwargs .get ("mini_batch_size" , int (n_compositional * 0.1 ))
814815
815816 # x is sampled from a normal distribution, must be scaled with var 1/n_compositional
816- n_compositional = ops .shape (conditions )[1 ]
817817 scale_latent = n_compositional * self .compositional_bridge (ops .ones (1 ))
818818 z = z / ops .sqrt (ops .cast (scale_latent , dtype = ops .dtype (z )))
819819
0 commit comments