|
6 | 6 | from PIL import Image
|
7 | 7 | import inspect
|
8 | 8 | import k_diffusion.sampling
|
| 9 | +import torchsde._brownian.brownian_interval |
9 | 10 | import ldm.models.diffusion.ddim
|
10 | 11 | import ldm.models.diffusion.plms
|
11 | 12 | from modules import prompt_parser, devices, processing, images
|
@@ -364,7 +365,23 @@ def randn_like(self, x):
|
364 | 365 | if noise.shape == x.shape:
|
365 | 366 | return noise
|
366 | 367 |
|
367 |
| - return torch.randn_like(x) |
| 368 | + if x.device.type == 'mps': |
| 369 | + return torch.randn_like(x, device=devices.cpu).to(x.device) |
| 370 | + else: |
| 371 | + return torch.randn_like(x) |
| 372 | + |
| 373 | + |
| 374 | +# MPS fix for randn in torchsde |
| 375 | +def torchsde_randn(size, dtype, device, seed): |
| 376 | + if device.type == 'mps': |
| 377 | + generator = torch.Generator(devices.cpu).manual_seed(int(seed)) |
| 378 | + return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) |
| 379 | + else: |
| 380 | + generator = torch.Generator(device).manual_seed(int(seed)) |
| 381 | + return torch.randn(size, dtype=dtype, device=device, generator=generator) |
| 382 | + |
| 383 | + |
| 384 | +torchsde._brownian.brownian_interval._randn = torchsde_randn |
368 | 385 |
|
369 | 386 |
|
370 | 387 | class KDiffusionSampler:
|
@@ -415,8 +432,7 @@ def initialize(self, p):
|
415 | 432 | self.model_wrap.step = 0
|
416 | 433 | self.eta = p.eta or opts.eta_ancestral
|
417 | 434 |
|
418 |
| - if self.sampler_noises is not None: |
419 |
| - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises) |
| 435 | + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) |
420 | 436 |
|
421 | 437 | extra_params_kwargs = {}
|
422 | 438 | for param_name in self.extra_params:
|
|
0 commit comments