diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index bea6e5e07543..7f2dd081577b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -38,7 +38,20 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [ + torchsde.BrownianInterval( + t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=s, + tol=1e-6, + pool_size=24, + halfway_tree=True, + ) + for s in seed + ] @staticmethod def sort(a, b):