Skip to content

Commit 3ee82d8

Browse files
committed
Merge branch 'toffaletti-dream-m1' into main
This provides support for Apple M1 hardware
2 parents 833de06 + 629ca09 commit 3ee82d8

File tree

4 files changed

+60
-41
lines changed

4 files changed

+60
-41
lines changed

environment-mac.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ dependencies:
5252
- -e git+https://github.com/huggingface/[email protected]#egg=diffusers
5353
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
5454
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
55-
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
55+
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
5656
- -e .
5757
variables:
5858
PYTORCH_ENABLE_MPS_FALLBACK: 1

ldm/dream/devices.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,10 @@ def choose_torch_device() -> str:
88
return 'mps'
99
return 'cpu'
1010

11-
11+
def choose_autocast_device(device) -> str:
12+
'''Returns an autocast compatible device from a torch device'''
13+
device_type = device.type # this returns 'mps' on M1
14+
# autocast only supports cuda or cpu
15+
if device_type not in ('cuda','cpu'):
16+
return 'cpu'
17+
return device_type

ldm/simplet2i.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import random
1010
import os
11+
import traceback
1112
from omegaconf import OmegaConf
1213
from PIL import Image
1314
from tqdm import tqdm, trange
@@ -28,7 +29,7 @@
2829
from ldm.models.diffusion.ksampler import KSampler
2930
from ldm.dream.pngwriter import PngWriter
3031
from ldm.dream.image_util import InitImageResizer
31-
from ldm.dream.devices import choose_torch_device
32+
from ldm.dream.devices import choose_autocast_device, choose_torch_device
3233

3334
"""Simplified text to image API for stable diffusion/latent diffusion
3435
@@ -114,26 +115,28 @@ class T2I:
114115
"""
115116

116117
def __init__(
117-
self,
118-
iterations=1,
119-
steps=50,
120-
seed=None,
121-
cfg_scale=7.5,
122-
weights='models/ldm/stable-diffusion-v1/model.ckpt',
123-
config='configs/stable-diffusion/v1-inference.yaml',
124-
grid=False,
125-
width=512,
126-
height=512,
127-
sampler_name='k_lms',
128-
latent_channels=4,
129-
downsampling_factor=8,
130-
ddim_eta=0.0, # deterministic
131-
precision='autocast',
132-
full_precision=False,
133-
strength=0.75, # default in scripts/img2img.py
134-
embedding_path=None,
135-
# just to keep track of this parameter when regenerating prompt
136-
latent_diffusion_weights=False,
118+
self,
119+
iterations=1,
120+
steps=50,
121+
seed=None,
122+
cfg_scale=7.5,
123+
weights='models/ldm/stable-diffusion-v1/model.ckpt',
124+
config='configs/stable-diffusion/v1-inference.yaml',
125+
grid=False,
126+
width=512,
127+
height=512,
128+
sampler_name='k_lms',
129+
latent_channels=4,
130+
downsampling_factor=8,
131+
ddim_eta=0.0, # deterministic
132+
precision='autocast',
133+
full_precision=False,
134+
strength=0.75, # default in scripts/img2img.py
135+
embedding_path=None,
136+
device_type = 'cuda',
137+
# just to keep track of this parameter when regenerating prompt
138+
# needs to be replaced when new configuration system implemented.
139+
latent_diffusion_weights=False,
137140
):
138141
self.iterations = iterations
139142
self.width = width
@@ -151,11 +154,17 @@ def __init__(
151154
self.full_precision = full_precision
152155
self.strength = strength
153156
self.embedding_path = embedding_path
157+
self.device_type = device_type
154158
self.model = None # empty for now
155159
self.sampler = None
156160
self.device = None
157161
self.latent_diffusion_weights = latent_diffusion_weights
158162

163+
if device_type == 'cuda' and not torch.cuda.is_available():
164+
device_type = choose_torch_device()
165+
print(">> cuda not available, using device", device_type)
166+
self.device = torch.device(device_type)
167+
159168
# for VRAM usage statistics
160169
device_type = choose_torch_device()
161170
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
@@ -312,8 +321,9 @@ def process_image(image,seed):
312321
callback=step_callback,
313322
)
314323

315-
with scope(self.device.type), self.model.ema_scope():
316-
for n in trange(iterations, desc='>> Generating'):
324+
device_type = choose_autocast_device(self.device)
325+
with scope(device_type), self.model.ema_scope():
326+
for n in trange(iterations, desc='Generating'):
317327
seed_everything(seed)
318328
image = next(images_iterator)
319329
results.append([image, seed])
@@ -346,7 +356,7 @@ def process_image(image,seed):
346356
)
347357
except Exception as e:
348358
print(
349-
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
359+
f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
350360
)
351361
if image_callback is not None:
352362
if save_original:
@@ -359,11 +369,11 @@ def process_image(image,seed):
359369
except KeyboardInterrupt:
360370
print('*interrupted*')
361371
print(
362-
'Partial results will be returned; if --grid was requested, nothing will be returned.'
372+
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
363373
)
364374
except RuntimeError as e:
365-
print(str(e))
366-
print('Are you sure your system has an adequate NVIDIA GPU?')
375+
print(traceback.format_exc(), file=sys.stderr)
376+
print('>> Are you sure your system has an adequate NVIDIA GPU?')
367377

368378
toc = time.time()
369379
print('>> Usage stats:')
@@ -464,7 +474,6 @@ def _img2img(
464474
)
465475

466476
t_enc = int(strength * steps)
467-
# print(f"target t_enc is {t_enc} steps")
468477

469478
while True:
470479
uc, c = self._get_uc_and_c(prompt, skip_normalize)
@@ -515,7 +524,7 @@ def _sample_to_image(self, samples):
515524
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
516525
if len(x_samples) != 1:
517526
raise Exception(
518-
f'expected to get a single image, but got {len(x_samples)}')
527+
f'>> expected to get a single image, but got {len(x_samples)}')
519528
x_sample = 255.0 * rearrange(
520529
x_samples[0].cpu().numpy(), 'c h w -> h w c'
521530
)
@@ -525,17 +534,12 @@ def _new_seed(self):
525534
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
526535
return self.seed
527536

528-
def _get_device(self):
529-
device_type = choose_torch_device()
530-
return torch.device(device_type)
531-
532537
def load_model(self):
533538
"""Load and initialize the model from configuration variables passed at object creation time"""
534539
if self.model is None:
535540
seed_everything(self.seed)
536541
try:
537542
config = OmegaConf.load(self.config)
538-
self.device = self._get_device()
539543
model = self._load_model_from_config(config, self.weights)
540544
if self.embedding_path is not None:
541545
model.embedding_manager.load(
@@ -544,12 +548,10 @@ def load_model(self):
544548
self.model = model.to(self.device)
545549
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
546550
self.model.cond_stage_model.device = self.device
547-
except AttributeError:
548-
import traceback
549-
print(
550-
'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
551+
except AttributeError as e:
552+
print(f'>> Error loading model. {str(e)}', file=sys.stderr)
551553
print(traceback.format_exc(), file=sys.stderr)
552-
raise SystemExit
554+
raise SystemExit from e
553555

554556
self._set_sampler()
555557

scripts/dream.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import copy
1010
import warnings
1111
import time
12+
from ldm.dream.devices import choose_torch_device
1213
import ldm.dream.readline
1314
from ldm.dream.pngwriter import PngWriter, PromptFormatter
1415
from ldm.dream.server import DreamServer, ThreadingDreamServer
@@ -60,6 +61,7 @@ def main():
6061
# this is solely for recreating the prompt
6162
latent_diffusion_weights=opt.laion400m,
6263
embedding_path=opt.embedding_path,
64+
device_type=opt.device
6365
)
6466

6567
# make sure the output directory exists
@@ -346,6 +348,8 @@ def create_argv_parser():
346348
dest='full_precision',
347349
action='store_true',
348350
help='Use slower full precision math for calculations',
351+
# MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
352+
default=choose_torch_device() == 'mps',
349353
)
350354
parser.add_argument(
351355
'-g',
@@ -418,6 +422,13 @@ def create_argv_parser():
418422
default='model',
419423
help='Indicates the Stable Diffusion model to use.',
420424
)
425+
parser.add_argument(
426+
'--device',
427+
'-d',
428+
type=str,
429+
default='cuda',
430+
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
431+
)
421432
return parser
422433

423434

0 commit comments

Comments
 (0)