Skip to content

Commit 146e75a

Browse files
committed
Merge branch 'bakkot-refactor-pngwriter-2' into main
This fixes regressions in the WebGUI and makes maintenance of pngwriter easier.
2 parents 462a196 + 8a2b849 commit 146e75a

File tree

6 files changed

+98
-114
lines changed

6 files changed

+98
-114
lines changed

ldm/dream/image_util.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from math import sqrt, floor, ceil
12
from PIL import Image
23

34
class InitImageResizer():
@@ -51,4 +52,22 @@ def resize(self,width=None,height=None) -> Image:
5152

5253
return new_image
5354

54-
55+
def make_grid(image_list, rows=None, cols=None):
56+
image_cnt = len(image_list)
57+
if None in (rows, cols):
58+
rows = floor(sqrt(image_cnt)) # try to make it square
59+
cols = ceil(image_cnt / rows)
60+
width = image_list[0].width
61+
height = image_list[0].height
62+
63+
grid_img = Image.new('RGB', (width * cols, height * rows))
64+
i = 0
65+
for r in range(0, rows):
66+
for c in range(0, cols):
67+
if i >= len(image_list):
68+
break
69+
grid_img.paste(image_list[i], (c * width, r * height))
70+
i = i + 1
71+
72+
return grid_img
73+

ldm/dream/pngwriter.py

Lines changed: 20 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,97 +2,42 @@
22
Two helper classes for dealing with PNG images and their path names.
33
PngWriter -- Converts Images generated by T2I into PNGs, finds
44
appropriate names for them, and writes prompt metadata
5-
into the PNG. Intended to be subclassable in order to
6-
create more complex naming schemes, including using the
7-
prompt for file/directory names.
5+
into the PNG.
86
PromptFormatter -- Utility for converting a Namespace of prompt parameters
97
back into a formatted prompt string with command-line switches.
108
"""
119
import os
1210
import re
13-
from math import sqrt, floor, ceil
14-
from PIL import Image, PngImagePlugin
11+
from PIL import PngImagePlugin
1512

1613
# -------------------image generation utils-----
1714

1815

1916
class PngWriter:
20-
def __init__(self, outdir, prompt=None):
17+
def __init__(self, outdir):
2118
self.outdir = outdir
22-
self.prompt = prompt
23-
self.filepath = None
24-
self.files_written = []
2519
os.makedirs(outdir, exist_ok=True)
2620

27-
def write_image(self, image, seed, upscaled=False):
28-
self.filepath = self.unique_filename(
29-
seed, upscaled, self.filepath
30-
) # will increment name in some sensible way
31-
try:
32-
prompt = f'{self.prompt} -S{seed}'
33-
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
34-
except IOError as e:
35-
print(e)
36-
if not upscaled:
37-
self.files_written.append([self.filepath, seed])
38-
39-
def unique_filename(self, seed, upscaled=False, previouspath=None):
40-
revision = 1
41-
42-
if previouspath is None:
43-
# sort reverse alphabetically until we find max+1
44-
dirlist = sorted(os.listdir(self.outdir), reverse=True)
45-
# find the first filename that matches our pattern or return 000000.0.png
46-
filename = next(
47-
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
48-
'0000000.0.png',
49-
)
50-
basecount = int(filename.split('.', 1)[0])
51-
basecount += 1
52-
filename = f'{basecount:06}.{seed}.png'
53-
return os.path.join(self.outdir, filename)
54-
55-
else:
56-
basename = os.path.basename(previouspath)
57-
x = re.match('^(\d+)\..*\.png', basename)
58-
if not x:
59-
return self.unique_filename(seed, upscaled, previouspath)
60-
61-
basecount = int(x.groups()[0])
62-
series = 0
63-
finished = False
64-
while not finished:
65-
series += 1
66-
filename = f'{basecount:06}.{seed}.png'
67-
path = os.path.join(self.outdir, filename)
68-
if os.path.exists(path) and upscaled:
69-
break
70-
finished = not os.path.exists(path)
71-
return os.path.join(self.outdir, filename)
72-
73-
def save_image_and_prompt_to_png(self, image, prompt, path):
21+
# gives the next unique prefix in outdir
22+
def unique_prefix(self):
23+
# sort reverse alphabetically until we find max+1
24+
dirlist = sorted(os.listdir(self.outdir), reverse=True)
25+
# find the first filename that matches our pattern or return 000000.0.png
26+
existing_name = next(
27+
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
28+
'0000000.0.png',
29+
)
30+
basecount = int(existing_name.split('.', 1)[0]) + 1
31+
return f'{basecount:06}'
32+
33+
# saves image named _image_ to outdir/name, writing metadata from prompt
34+
# returns full path of output
35+
def save_image_and_prompt_to_png(self, image, prompt, name):
36+
path = os.path.join(self.outdir, name)
7437
info = PngImagePlugin.PngInfo()
7538
info.add_text('Dream', prompt)
7639
image.save(path, 'PNG', pnginfo=info)
77-
78-
def make_grid(self, image_list, rows=None, cols=None):
79-
image_cnt = len(image_list)
80-
if None in (rows, cols):
81-
rows = floor(sqrt(image_cnt)) # try to make it square
82-
cols = ceil(image_cnt / rows)
83-
width = image_list[0].width
84-
height = image_list[0].height
85-
86-
grid_img = Image.new('RGB', (width * cols, height * rows))
87-
i = 0
88-
for r in range(0, rows):
89-
for c in range(0, cols):
90-
if i>=len(image_list):
91-
break
92-
grid_img.paste(image_list[i], (c * width, r * height))
93-
i = i + 1
94-
95-
return grid_img
40+
return path
9641

9742

9843
class PromptFormatter:

ldm/dream/server.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,24 @@ def do_POST(self):
8888

8989
images_generated = 0 # helps keep track of when upscaling is started
9090
images_upscaled = 0 # helps keep track of when upscaling is completed
91-
pngwriter = PngWriter(
92-
"./outputs/img-samples/", config['prompt'], 1
93-
)
91+
pngwriter = PngWriter("./outputs/img-samples/")
9492

93+
prefix = pngwriter.unique_prefix()
9594
# if upscaling is requested, then this will be called twice, once when
9695
# the images are first generated, and then again when after upscaling
9796
# is complete. The upscaling replaces the original file, so the second
9897
# entry should not be inserted into the image list.
9998
def image_done(image, seed, upscaled=False):
100-
pngwriter.write_image(image, seed, upscaled)
99+
name = f'{prefix}.{seed}.png'
100+
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
101101

102102
# Append post_data to log, but only once!
103103
if not upscaled:
104-
current_image = pngwriter.files_written[-1]
105104
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
106-
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
105+
log.write(f"{path}: {json.dumps(config)}\n")
106+
107107
self.wfile.write(bytes(json.dumps(
108-
{'event':'result', 'files':current_image, 'config':config}
108+
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
109109
) + '\n',"utf-8"))
110110

111111
# control state of the "postprocessing..." message
@@ -129,22 +129,24 @@ def image_done(image, seed, upscaled=False):
129129
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
130130
) + '\n',"utf-8"))
131131

132-
# TODO: refactor PngWriter:
133-
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
134-
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
132+
step_writer = PngWriter('./outputs/intermediates/')
133+
step_index = 1
135134
def image_progress(sample, step):
136135
if self.canceled.is_set():
137136
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
138137
raise CanceledException
139-
url = None
138+
path = None
140139
# since rendering images is moderately expensive, only render every 5th image
141140
# and don't bother with the last one, since it'll render anyway
141+
nonlocal step_index
142142
if progress_images and step % 5 == 0 and step < steps - 1:
143143
image = self.model._sample_to_image(sample)
144-
step_writer.write_image(image, seed) # TODO PngWriter to return path
145-
url = step_writer.filepath
144+
name = f'{prefix}.{seed}.{step_index}.png'
145+
metadata = f'{prompt} -S{seed} [intermediate]'
146+
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
147+
step_index += 1
146148
self.wfile.write(bytes(json.dumps(
147-
{'event':'step', 'step':step + 1, 'url': url}
149+
{'event': 'step', 'step': step + 1, 'url': path}
148150
) + '\n',"utf-8"))
149151

150152
try:

ldm/simplet2i.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,14 @@ def prompt2png(self, prompt, outdir, **kwargs):
171171
Optional named arguments are the same as those passed to T2I and prompt2image()
172172
"""
173173
results = self.prompt2image(prompt, **kwargs)
174-
pngwriter = PngWriter(outdir, prompt)
175-
for r in results:
176-
pngwriter.write_image(r[0], r[1])
177-
return pngwriter.files_written
174+
pngwriter = PngWriter(outdir)
175+
prefix = pngwriter.unique_prefix()
176+
outputs = []
177+
for image, seed in results:
178+
name = f'{prefix}.{seed}.png'
179+
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
180+
outputs.append([path, seed])
181+
return outputs
178182

179183
def txt2img(self, prompt, **kwargs):
180184
outdir = kwargs.pop('outdir', 'outputs/img-samples')
@@ -349,10 +353,7 @@ def process_image(image,seed):
349353
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
350354
)
351355
if image_callback is not None:
352-
if save_original:
353-
image_callback(image, seed)
354-
else:
355-
image_callback(image, seed, upscaled=True)
356+
image_callback(image, seed, upscaled=True)
356357
else: # no callback passed, so we simply replace old image with rescaled one
357358
result[0] = image
358359

scripts/dream.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import ldm.dream.readline
1313
from ldm.dream.pngwriter import PngWriter, PromptFormatter
1414
from ldm.dream.server import DreamServer, ThreadingDreamServer
15+
from ldm.dream.image_util import make_grid
1516

1617
def main():
1718
"""Initialize command-line parsers and the diffusion model"""
@@ -203,24 +204,40 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
203204

204205
# Here is where the images are actually generated!
205206
try:
206-
file_writer = PngWriter(current_outdir, normalized_prompt)
207-
callback = file_writer.write_image if individual_images else None
208-
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
209-
results = (
210-
file_writer.files_written if individual_images else image_list
211-
)
212-
213-
if do_grid and len(results) > 0:
214-
grid_img = file_writer.make_grid([r[0] for r in results])
215-
filename = file_writer.unique_filename(results[0][1])
216-
seeds = [a[1] for a in results]
217-
results = [[filename, seeds]]
218-
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
219-
file_writer.save_image_and_prompt_to_png(
207+
file_writer = PngWriter(current_outdir)
208+
prefix = file_writer.unique_prefix()
209+
seeds = set()
210+
results = []
211+
grid_images = dict() # seed -> Image, only used if `do_grid`
212+
def image_writer(image, seed, upscaled=False):
213+
if do_grid:
214+
grid_images[seed] = image
215+
else:
216+
if upscaled and opt.save_original:
217+
filename = f'{prefix}.{seed}.postprocessed.png'
218+
else:
219+
filename = f'{prefix}.{seed}.png'
220+
path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename)
221+
if (not upscaled) or opt.save_original:
222+
# only append to results if we didn't overwrite an earlier output
223+
results.append([path, seed])
224+
225+
seeds.add(seed)
226+
227+
t2i.prompt2image(image_callback=image_writer, **vars(opt))
228+
229+
if do_grid and len(grid_images) > 0:
230+
grid_img = make_grid(list(grid_images.values()))
231+
first_seed = next(iter(seeds))
232+
filename = f'{prefix}.{first_seed}.png'
233+
# TODO better metadata for grid images
234+
metadata_prompt = f'{normalized_prompt} -S{first_seed}'
235+
path = file_writer.save_image_and_prompt_to_png(
220236
grid_img, metadata_prompt, filename
221237
)
238+
results = [[path, seeds]]
222239

223-
last_seeds = [r[1] for r in results]
240+
last_seeds = list(seeds)
224241

225242
except AssertionError as e:
226243
print(e)

static/dream_web/index.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ async function generateSubmit(form) {
9595
if (data.event === 'result') {
9696
noOutputs = false;
9797
document.querySelector("#no-results-message")?.remove();
98-
appendOutput(data.files[0],data.files[1],data.config);
98+
appendOutput(data.url, data.seed, data.config);
9999
progressEle.setAttribute('value', 0);
100100
progressEle.setAttribute('max', totalSteps);
101101
progressImageEle.src = BLANK_IMAGE_URL;

0 commit comments

Comments
 (0)