Skip to content

Commit ef55b8f

Browse files
authored
fix: 🐛 Fix the torch.median() on apple gpu (#94)
Closes: #1
1 parent c5f1c0a commit ef55b8f

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

lib_layerdiffusion/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,16 @@ def estimate_augmented(self, pixel, latent):
299299
result += [eps]
300300

301301
result = torch.stack(result, dim=0)
302-
median = torch.median(result, dim=0).values
302+
if self.load_device == torch.device("mps"):
303+
'''
304+
In case that apple silicon devices would crash when calling torch.median() on tensors
305+
in gpu vram with dimensions higher than 4, we move it to cpu, call torch.median()
306+
and then move the result back to gpu.
307+
'''
308+
median = torch.median(result.cpu(), dim=0).values
309+
median = median.to(device=self.load_device, dtype=self.dtype)
310+
else:
311+
median = torch.median(result, dim=0).values
303312
return median
304313

305314
@torch.no_grad()

0 commit comments

Comments
 (0)