File tree Expand file tree Collapse file tree 2 files changed +9
-5
lines changed
Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -8,4 +8,10 @@ def choose_torch_device() -> str:
88 return 'mps'
99 return 'cpu'
1010
11-
11+ def choose_autocast_device (device ) -> str :
12+ '''Returns an autocast compatible device from a torch device'''
13+ device_type = device .type # this returns 'mps' on M1
14+ # autocast only supports cuda or cpu
15+ if device_type != 'cuda' or device_type != 'cpu' :
16+ return 'cpu'
17+ return device_type
Original file line number Diff line number Diff line change 2727from ldm .models .diffusion .plms import PLMSSampler
2828from ldm .models .diffusion .ksampler import KSampler
2929from ldm .dream .pngwriter import PngWriter
30- from ldm .dream .devices import choose_torch_device
30+ from ldm .dream .devices import choose_autocast_device , choose_torch_device
3131
3232"""Simplified text to image API for stable diffusion/latent diffusion
3333
@@ -315,9 +315,7 @@ def process_image(image,seed):
315315 callback = step_callback ,
316316 )
317317
318- device_type = self .device .type # this returns 'mps' on M1
319- if device_type != 'cuda' or device_type != 'cpu' :
320- device_type = 'cpu'
318+ device_type = choose_autocast_device (self .device )
321319 with scope (device_type ), self .model .ema_scope ():
322320 for n in trange (iterations , desc = 'Generating' ):
323321 seed_everything (seed )
You can’t perform that action at this time.
0 commit comments