|
9 | 9 | import copy |
10 | 10 | import warnings |
11 | 11 | import time |
12 | | -from ldm.dream.devices import choose_torch_device |
13 | 12 | import ldm.dream.readline |
14 | 13 | from ldm.dream.pngwriter import PngWriter, PromptFormatter |
15 | 14 | from ldm.dream.server import DreamServer, ThreadingDreamServer |
16 | 15 | from ldm.dream.image_util import make_grid |
| 16 | +from omegaconf import OmegaConf |
17 | 17 |
|
18 | 18 | def main(): |
19 | 19 | """Initialize command-line parsers and the diffusion model""" |
20 | 20 | arg_parser = create_argv_parser() |
21 | 21 | opt = arg_parser.parse_args() |
| 22 | + |
22 | 23 | if opt.laion400m: |
23 | | - # defaults suitable to the older latent diffusion weights |
24 | | - width = 256 |
25 | | - height = 256 |
26 | | - config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml' |
27 | | - weights = 'models/ldm/text2img-large/model.ckpt' |
28 | | - else: |
29 | | - # some defaults suitable for stable diffusion weights |
30 | | - width = 512 |
31 | | - height = 512 |
32 | | - config = 'configs/stable-diffusion/v1-inference.yaml' |
33 | | - if '.ckpt' in opt.weights: |
34 | | - weights = opt.weights |
35 | | - else: |
36 | | - weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt' |
| 24 | + print('--laion400m flag has been deprecated. Please use --model laion400m instead.') |
| 25 | + sys.exit(-1) |
| 26 | + if opt.weights != 'model': |
| 27 | + print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') |
| 28 | + sys.exit(-1) |
| 29 | + |
| 30 | + try: |
| 31 | + print(f'attempting to load {opt.config}') |
| 32 | + models = OmegaConf.load(opt.config) |
| 33 | + width = models[opt.model].width |
| 34 | + height = models[opt.model].height |
| 35 | + config = models[opt.model].config |
| 36 | + weights = models[opt.model].weights |
| 37 | + except (FileNotFoundError, IOError, KeyError) as e: |
| 38 | + print(f'{e}. Aborting.') |
| 39 | + sys.exit(-1) |
37 | 40 |
|
38 | 41 | print('* Initializing, be patient...\n') |
39 | 42 | sys.path.append('.') |
@@ -348,8 +351,6 @@ def create_argv_parser(): |
348 | 351 | dest='full_precision', |
349 | 352 | action='store_true', |
350 | 353 | help='Use slower full precision math for calculations', |
351 | | - # MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237 |
352 | | - default=choose_torch_device() == 'mps', |
353 | 354 | ) |
354 | 355 | parser.add_argument( |
355 | 356 | '-g', |
@@ -429,6 +430,16 @@ def create_argv_parser(): |
429 | 430 | default='cuda', |
430 | 431 | help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available" |
431 | 432 | ) |
| 433 | + parser.add_argument( |
| 434 | + '--model', |
| 435 | + default='stable-diffusion-1.4', |
| 436 | + help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")', |
| 437 | + ) |
| 438 | + parser.add_argument( |
| 439 | + '--config', |
| 440 | + default ='configs/models.yaml', |
| 441 | + help ='Path to configuration file for alternate models.', |
| 442 | + ) |
432 | 443 | return parser |
433 | 444 |
|
434 | 445 |
|
|
0 commit comments