Skip to content

Commit 6566c22

Browse files
committed
add scalable support for new models using a configs/models.yaml file
2 parents 18cdb55 + 063b4a1 commit 6566c22

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,13 @@ be fast because all the dependencies are already loaded.
739739
Anyone who wishes to contribute to this project, whether
740740
documentation, features, bug fixes, code cleanup, testing, or code
741741
reviews, is very much encouraged to do so. If you are unfamiliar with
742+
<<<<<<< HEAD
742743
how to contribute to GitHub projects, here is a [Getting Started
743744
Guide](https://opensource.com/article/19/7/create-pull-request-github).
745+
=======
746+
how to contribute to GitHub projects, here is a (Getting Started
747+
Guide)[https://opensource.com/article/19/7/create-pull-request-github].
748+
>>>>>>> maddavid123-main
744749
745750
A full set of contribution guidelines, along with templates, are in
746751
progress, but for now the most important thing is to **make your pull

configs/models.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# This file describes the alternative machine learning models
2+
# available to the dream script.
3+
#
4+
# To add a new model, follow the examples below. Each
5+
# model requires a model config file, a weights file,
6+
# and the width and height of the images it
7+
# was trained on.
8+
9+
laion400m:
10+
config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
11+
weights: models/ldm/text2img-large/model.ckpt
12+
width: 256
13+
height: 256
14+
stable-diffusion-1.4:
15+
config: configs/stable-diffusion/v1-inference.yaml
16+
weights: models/ldm/stable-diffusion-v1/model.ckpt
17+
width: 512
18+
height: 512

scripts/dream.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,29 @@
1313
from ldm.dream.pngwriter import PngWriter, PromptFormatter
1414
from ldm.dream.server import DreamServer, ThreadingDreamServer
1515
from ldm.dream.image_util import make_grid
16+
from omegaconf import OmegaConf
1617

1718
def main():
1819
"""Initialize command-line parsers and the diffusion model"""
1920
arg_parser = create_argv_parser()
2021
opt = arg_parser.parse_args()
22+
2123
if opt.laion400m:
22-
# defaults suitable to the older latent diffusion weights
23-
width = 256
24-
height = 256
25-
config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml'
26-
weights = 'models/ldm/text2img-large/model.ckpt'
27-
else:
28-
# some defaults suitable for stable diffusion weights
29-
width = 512
30-
height = 512
31-
config = 'configs/stable-diffusion/v1-inference.yaml'
32-
if '.ckpt' in opt.weights:
33-
weights = opt.weights
34-
else:
35-
weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
24+
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
25+
sys.exit(-1)
26+
if opt.weights != 'model':
27+
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
28+
sys.exit(-1)
29+
30+
try:
31+
models = OmegaConf.load(opt.config)
32+
width = models[opt.model].width
33+
height = models[opt.model].height
34+
config = models[opt.model].config
35+
weights = models[opt.model].weights
36+
except (FileNotFoundError, IOError, KeyError) as e:
37+
print(f'{e}. Aborting.')
38+
sys.exit(-1)
3639

3740
print('* Initializing, be patient...\n')
3841
sys.path.append('.')
@@ -482,6 +485,16 @@ def create_argv_parser():
482485
default='cuda',
483486
help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
484487
)
488+
parser.add_argument(
489+
'--model',
490+
default='stable-diffusion-1.4',
491+
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
492+
)
493+
parser.add_argument(
494+
'--config',
495+
default ='configs/models.yaml',
496+
help ='Path to configuration file for alternate models.',
497+
)
485498
return parser
486499

487500

0 commit comments

Comments
 (0)