Skip to content

Commit 063b4a1

Browse files
committed
add ability to specify location of config file (models.yaml)
2 parents 51278c7 + 68eabab commit 063b4a1

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed

configs/models.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# This file describes the alternative machine learning models
2+
# available to the dream script.
3+
#
4+
# To add a new model, follow the examples below. Each
5+
# model requires a model config file, a weights file,
6+
# and the width and height of the images it
7+
# was trained on.
8+
9+
laion400m:
10+
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
11+
weights: models/ldm/text2img-large/model.ckpt
12+
width: 256
13+
height: 256
14+
stable-diffusion-1.4:
15+
config: configs/stable-diffusion/v1-inference.yaml
16+
weights: models/ldm/stable-diffusion-v1/model.ckpt
17+
width: 512
18+
height: 512

scripts/dream.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,34 @@
99
import copy
1010
import warnings
1111
import time
12-
from ldm.dream.devices import choose_torch_device
1312
import ldm.dream.readline
1413
from ldm.dream.pngwriter import PngWriter, PromptFormatter
1514
from ldm.dream.server import DreamServer, ThreadingDreamServer
1615
from ldm.dream.image_util import make_grid
16+
from omegaconf import OmegaConf
1717

1818
def main():
1919
"""Initialize command-line parsers and the diffusion model"""
2020
arg_parser = create_argv_parser()
2121
opt = arg_parser.parse_args()
22+
2223
if opt.laion400m:
23-
# defaults suitable to the older latent diffusion weights
24-
width = 256
25-
height = 256
26-
config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
27-
weights = 'models/ldm/text2img-large/model.ckpt'
28-
else:
29-
# some defaults suitable for stable diffusion weights
30-
width = 512
31-
height = 512
32-
config = 'configs/stable-diffusion/v1-inference.yaml'
33-
if '.ckpt' in opt.weights:
34-
weights = opt.weights
35-
else:
36-
weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
24+
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
25+
sys.exit(-1)
26+
if opt.weights != 'model':
27+
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
28+
sys.exit(-1)
29+
30+
try:
31+
print(f'attempting to load {opt.config}')
32+
models = OmegaConf.load(opt.config)
33+
width = models[opt.model].width
34+
height = models[opt.model].height
35+
config = models[opt.model].config
36+
weights = models[opt.model].weights
37+
except (FileNotFoundError, IOError, KeyError) as e:
38+
print(f'{e}. Aborting.')
39+
sys.exit(-1)
3740

3841
print('* Initializing, be patient...\n')
3942
sys.path.append('.')
@@ -348,8 +351,6 @@ def create_argv_parser():
348351
dest='full_precision',
349352
action='store_true',
350353
help='Use slower full precision math for calculations',
351-
# MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
352-
default=choose_torch_device() == 'mps',
353354
)
354355
parser.add_argument(
355356
'-g',
@@ -429,6 +430,16 @@ def create_argv_parser():
429430
default='cuda',
430431
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
431432
)
433+
parser.add_argument(
434+
'--model',
435+
default='stable-diffusion-1.4',
436+
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
437+
)
438+
parser.add_argument(
439+
'--config',
440+
default ='configs/models.yaml',
441+
help ='Path to configuration file for alternate models.',
442+
)
432443
return parser
433444

434445

0 commit comments

Comments
 (0)