|
13 | 13 | from ldm.dream.pngwriter import PngWriter, PromptFormatter |
14 | 14 | from ldm.dream.server import DreamServer, ThreadingDreamServer |
15 | 15 | from ldm.dream.image_util import make_grid |
| 16 | +from omegaconf import OmegaConf |
16 | 17 |
|
17 | 18 | def main(): |
18 | 19 | """Initialize command-line parsers and the diffusion model""" |
19 | 20 | arg_parser = create_argv_parser() |
20 | 21 | opt = arg_parser.parse_args() |
| 22 | + |
21 | 23 | if opt.laion400m: |
22 | | - # defaults suitable to the older latent diffusion weights |
23 | | - width = 256 |
24 | | - height = 256 |
25 | | - config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml' |
26 | | - weights = 'models/ldm/text2img-large/model.ckpt' |
27 | | - else: |
28 | | - # some defaults suitable for stable diffusion weights |
29 | | - width = 512 |
30 | | - height = 512 |
31 | | - config = 'configs/stable-diffusion/v1-inference.yaml' |
32 | | - if '.ckpt' in opt.weights: |
33 | | - weights = opt.weights |
34 | | - else: |
35 | | - weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt' |
| 24 | + print('--laion400m flag has been deprecated. Please use --model laion400m instead.') |
| 25 | + sys.exit(-1) |
| 26 | + if opt.weights != 'model': |
| 27 | + print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') |
| 28 | + sys.exit(-1) |
| 29 | + |
| 30 | + try: |
| 31 | + models = OmegaConf.load(opt.config) |
| 32 | + width = models[opt.model].width |
| 33 | + height = models[opt.model].height |
| 34 | + config = models[opt.model].config |
| 35 | + weights = models[opt.model].weights |
| 36 | + except (FileNotFoundError, IOError, KeyError) as e: |
| 37 | + print(f'{e}. Aborting.') |
| 38 | + sys.exit(-1) |
36 | 39 |
|
37 | 40 | print('* Initializing, be patient...\n') |
38 | 41 | sys.path.append('.') |
@@ -482,6 +485,16 @@ def create_argv_parser(): |
482 | 485 | default='cuda', |
483 | 486 | help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available" |
484 | 487 | ) |
| 488 | + parser.add_argument( |
| 489 | + '--model', |
| 490 | + default='stable-diffusion-1.4', |
| 491 | + help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")', |
| 492 | + ) |
| 493 | + parser.add_argument( |
| 494 | + '--config', |
| 495 | + default ='configs/models.yaml', |
| 496 | + help ='Path to configuration file for alternate models.', |
| 497 | + ) |
485 | 498 | return parser |
486 | 499 |
|
487 | 500 |
|
|
0 commit comments