File tree Expand file tree Collapse file tree 2 files changed +8
-15
lines changed Expand file tree Collapse file tree 2 files changed +8
-15
lines changed Original file line number Diff line number Diff line change @@ -66,24 +66,15 @@ def enable_tf32():
66
66
67
67
68
68
def randn (seed , shape ):
69
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
70
- if device .type == 'mps' :
71
- generator = torch .Generator (device = cpu )
72
- generator .manual_seed (seed )
73
- noise = torch .randn (shape , generator = generator , device = cpu ).to (device )
74
- return noise
75
-
76
69
torch .manual_seed (seed )
70
+ if device .type == 'mps' :
71
+ return torch .randn (shape , device = cpu ).to (device )
77
72
return torch .randn (shape , device = device )
78
73
79
74
80
75
def randn_without_seed (shape ):
81
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
82
76
if device .type == 'mps' :
83
- generator = torch .Generator (device = cpu )
84
- noise = torch .randn (shape , generator = generator , device = cpu ).to (device )
85
- return noise
86
-
77
+ return torch .randn (shape , device = cpu ).to (device )
87
78
return torch .randn (shape , device = device )
88
79
89
80
Original file line number Diff line number Diff line change @@ -365,7 +365,10 @@ def randn_like(self, x):
365
365
if noise .shape == x .shape :
366
366
return noise
367
367
368
- return torch .randn_like (x )
368
+ if x .device .type == 'mps' :
369
+ return torch .randn_like (x , device = devices .cpu ).to (x .device )
370
+ else :
371
+ return torch .randn_like (x )
369
372
370
373
371
374
# MPS fix for randn in torchsde
@@ -429,8 +432,7 @@ def initialize(self, p):
429
432
self .model_wrap .step = 0
430
433
self .eta = p .eta or opts .eta_ancestral
431
434
432
- if self .sampler_noises is not None :
433
- k_diffusion .sampling .torch = TorchHijack (self .sampler_noises )
435
+ k_diffusion .sampling .torch = TorchHijack (self .sampler_noises if self .sampler_noises is not None else [])
434
436
435
437
extra_params_kwargs = {}
436
438
for param_name in self .extra_params :
You can’t perform that action at this time.
0 commit comments