Skip to content

Commit b19827f

Browse files
dianyoyiyixuxu
andauthored
Migrate the BrownianTree to BrownianInterval in DPM solver (#9335)
migrate the BrownianTree to BrownianInterval Co-authored-by: YiYi Xu <[email protected]>
1 parent c002731 commit b19827f

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,20 @@ def __init__(self, x, t0, t1, seed=None, **kwargs):
3838
except TypeError:
3939
seed = [seed]
4040
self.batched = False
41-
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
41+
self.trees = [
42+
torchsde.BrownianInterval(
43+
t0=t0,
44+
t1=t1,
45+
size=w0.shape,
46+
dtype=w0.dtype,
47+
device=w0.device,
48+
entropy=s,
49+
tol=1e-6,
50+
pool_size=24,
51+
halfway_tree=True,
52+
)
53+
for s in seed
54+
]
4255

4356
@staticmethod
4457
def sort(a, b):

0 commit comments

Comments
 (0)