Skip to content

Commit 09bd9fa

Browse files
committed
move autocast device selection to a function
1 parent fa98601 commit 09bd9fa

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

ldm/dream/devices.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,10 @@ def choose_torch_device() -> str:
88
return 'mps'
99
return 'cpu'
1010

11-
11+
def choose_autocast_device(device) -> str:
12+
'''Returns an autocast compatible device from a torch device'''
13+
device_type = device.type # this returns 'mps' on M1
14+
# autocast only supports cuda or cpu
15+
if device_type != 'cuda' or device_type != 'cpu':
16+
return 'cpu'
17+
return device_type

ldm/simplet2i.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ldm.models.diffusion.plms import PLMSSampler
2828
from ldm.models.diffusion.ksampler import KSampler
2929
from ldm.dream.pngwriter import PngWriter
30-
from ldm.dream.devices import choose_torch_device
30+
from ldm.dream.devices import choose_autocast_device, choose_torch_device
3131

3232
"""Simplified text to image API for stable diffusion/latent diffusion
3333
@@ -315,9 +315,7 @@ def process_image(image,seed):
315315
callback=step_callback,
316316
)
317317

318-
device_type = self.device.type # this returns 'mps' on M1
319-
if device_type != 'cuda' or device_type != 'cpu':
320-
device_type = 'cpu'
318+
device_type = choose_autocast_device(self.device)
321319
with scope(device_type), self.model.ema_scope():
322320
for n in trange(iterations, desc='Generating'):
323321
seed_everything(seed)

0 commit comments

Comments
 (0)