Skip to content

Commit 153c93b

Browse files
committed
refactor pngwriter
1 parent 3be1cee commit 153c93b

File tree

4 files changed

+76
-86
lines changed

4 files changed

+76
-86
lines changed

ldm/dream/pngwriter.py

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,62 +17,32 @@
1717

1818

1919
class PngWriter:
20-
def __init__(self, outdir, prompt=None):
20+
def __init__(self, outdir):
2121
self.outdir = outdir
22-
self.prompt = prompt
23-
self.filepath = None
24-
self.files_written = []
2522
os.makedirs(outdir, exist_ok=True)
2623

27-
def write_image(self, image, seed, upscaled=False):
28-
self.filepath = self.unique_filename(
29-
seed, upscaled, self.filepath
30-
) # will increment name in some sensible way
31-
try:
32-
prompt = f'{self.prompt} -S{seed}'
33-
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
34-
except IOError as e:
35-
print(e)
36-
if not upscaled:
37-
self.files_written.append([self.filepath, seed])
38-
39-
def unique_filename(self, seed, upscaled=False, previouspath=None):
40-
revision = 1
41-
42-
if previouspath is None:
43-
# sort reverse alphabetically until we find max+1
44-
dirlist = sorted(os.listdir(self.outdir), reverse=True)
45-
# find the first filename that matches our pattern or return 000000.0.png
46-
filename = next(
47-
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
48-
'0000000.0.png',
49-
)
50-
basecount = int(filename.split('.', 1)[0])
51-
basecount += 1
52-
filename = f'{basecount:06}.{seed}.png'
53-
return os.path.join(self.outdir, filename)
54-
55-
else:
56-
basename = os.path.basename(previouspath)
57-
x = re.match('^(\d+)\..*\.png', basename)
58-
if not x:
59-
return self.unique_filename(seed, upscaled, previouspath)
60-
61-
basecount = int(x.groups()[0])
62-
series = 0
63-
finished = False
64-
while not finished:
65-
series += 1
66-
filename = f'{basecount:06}.{seed}.png'
67-
path = os.path.join(self.outdir, filename)
68-
finished = not os.path.exists(path)
69-
return os.path.join(self.outdir, filename)
70-
71-
def save_image_and_prompt_to_png(self, image, prompt, path):
24+
# gives the next unique prefix in outdir
25+
def unique_prefix(self):
26+
# sort reverse alphabetically until we find max+1
27+
dirlist = sorted(os.listdir(self.outdir), reverse=True)
28+
# find the first filename that matches our pattern or return 000000.0.png
29+
existing_name = next(
30+
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
31+
'0000000.0.png',
32+
)
33+
basecount = int(existing_name.split('.', 1)[0]) + 1
34+
return f'{basecount:06}'
35+
36+
# saves image named _image_ to outdir/name, writing metadata from prompt
37+
# returns full path of output
38+
def save_image_and_prompt_to_png(self, image, prompt, name):
39+
path = os.path.join(self.outdir, name)
7240
info = PngImagePlugin.PngInfo()
7341
info.add_text('Dream', prompt)
7442
image.save(path, 'PNG', pnginfo=info)
43+
return path
7544

45+
# TODO move this to its own helper function; it's not really a method of pngwriter
7646
def make_grid(self, image_list, rows=None, cols=None):
7747
image_cnt = len(image_list)
7848
if None in (rows, cols):

ldm/dream/server.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,24 +88,25 @@ def do_POST(self):
8888

8989
images_generated = 0 # helps keep track of when upscaling is started
9090
images_upscaled = 0 # helps keep track of when upscaling is completed
91-
pngwriter = PngWriter(
92-
"./outputs/img-samples/", config['prompt'], 1
93-
)
91+
pngwriter = PngWriter("./outputs/img-samples/")
9492

93+
prefix = pngwriter.unique_prefix()
9594
# if upscaling is requested, then this will be called twice, once when
9695
# the images are first generated, and then again when after upscaling
9796
# is complete. The upscaling replaces the original file, so the second
9897
# entry should not be inserted into the image list.
9998
def image_done(image, seed, upscaled=False):
100-
pngwriter.write_image(image, seed, upscaled)
99+
name = f'{prefix}.{seed}.png'
100+
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
101101

102102
# Append post_data to log, but only once!
103103
if not upscaled:
104-
current_image = pngwriter.files_written[-1]
105104
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
106-
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
105+
log.write(f"{path}: {json.dumps(config)}\n")
106+
107+
# TODO fix format of this event
107108
self.wfile.write(bytes(json.dumps(
108-
{'event':'result', 'files':current_image, 'config':config}
109+
{'event': 'result', 'files': [path, seed], 'config': config}
109110
) + '\n',"utf-8"))
110111

111112
# control state of the "postprocessing..." message
@@ -129,22 +130,24 @@ def image_done(image, seed, upscaled=False):
129130
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
130131
) + '\n',"utf-8"))
131132

132-
# TODO: refactor PngWriter:
133-
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
134-
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
133+
step_writer = PngWriter('./outputs/intermediates/')
134+
step_index = 1
135135
def image_progress(sample, step):
136136
if self.canceled.is_set():
137137
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
138138
raise CanceledException
139-
url = None
139+
path = None
140140
# since rendering images is moderately expensive, only render every 5th image
141141
# and don't bother with the last one, since it'll render anyway
142+
nonlocal step_index
142143
if progress_images and step % 5 == 0 and step < steps - 1:
143144
image = self.model._sample_to_image(sample)
144-
step_writer.write_image(image, seed) # TODO PngWriter to return path
145-
url = step_writer.filepath
145+
name = f'{prefix}.{seed}.{step_index}.png'
146+
metadata = f'{prompt} -S{seed} [intermediate]'
147+
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
148+
step_index += 1
146149
self.wfile.write(bytes(json.dumps(
147-
{'event':'step', 'step':step + 1, 'url': url}
150+
{'event': 'step', 'step': step + 1, 'url': path}
148151
) + '\n',"utf-8"))
149152

150153
try:

ldm/simplet2i.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,14 @@ def prompt2png(self, prompt, outdir, **kwargs):
171171
Optional named arguments are the same as those passed to T2I and prompt2image()
172172
"""
173173
results = self.prompt2image(prompt, **kwargs)
174-
pngwriter = PngWriter(outdir, prompt)
175-
for r in results:
176-
pngwriter.write_image(r[0], r[1])
177-
return pngwriter.files_written
174+
pngwriter = PngWriter(outdir)
175+
prefix = pngwriter.unique_prefix()
176+
outputs = []
177+
for image, seed in results:
178+
name = f'{prefix}.{seed}.png'
179+
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
180+
outputs.append([path, seed])
181+
return outputs
178182

179183
def txt2img(self, prompt, **kwargs):
180184
outdir = kwargs.pop('outdir', 'outputs/img-samples')
@@ -349,10 +353,7 @@ def process_image(image,seed):
349353
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
350354
)
351355
if image_callback is not None:
352-
if save_original:
353-
image_callback(image, seed)
354-
else:
355-
image_callback(image, seed, upscaled=True)
356+
image_callback(image, seed, upscaled=True)
356357
else: # no callback passed, so we simply replace old image with rescaled one
357358
result[0] = image
358359

scripts/dream.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -203,24 +203,40 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
203203

204204
# Here is where the images are actually generated!
205205
try:
206-
file_writer = PngWriter(current_outdir, normalized_prompt)
207-
callback = file_writer.write_image if individual_images else None
208-
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
209-
results = (
210-
file_writer.files_written if individual_images else image_list
211-
)
212-
213-
if do_grid and len(results) > 0:
214-
grid_img = file_writer.make_grid([r[0] for r in results])
215-
filename = file_writer.unique_filename(results[0][1])
216-
seeds = [a[1] for a in results]
217-
results = [[filename, seeds]]
218-
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
219-
file_writer.save_image_and_prompt_to_png(
206+
file_writer = PngWriter(current_outdir)
207+
prefix = file_writer.unique_prefix()
208+
seeds = set()
209+
results = []
210+
grid_images = dict() # seed -> Image, only used if `do_grid`
211+
def image_writer(image, seed, upscaled=False):
212+
if do_grid:
213+
grid_images[seed] = image
214+
else:
215+
if upscaled and opt.save_original:
216+
filename = f'{prefix}.{seed}.postprocessed.png'
217+
else:
218+
filename = f'{prefix}.{seed}.png'
219+
path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename)
220+
if (not upscaled) or opt.save_original:
221+
# only append to results if we didn't overwrite an earlier output
222+
results.append([path, seed])
223+
224+
seeds.add(seed)
225+
226+
t2i.prompt2image(image_callback=image_writer, **vars(opt))
227+
228+
if do_grid and len(grid_images) > 0:
229+
grid_img = file_writer.make_grid(list(grid_images.values()))
230+
first_seed = next(iter(seeds))
231+
filename = f'{prefix}.{first_seed}.png'
232+
# TODO better metadata for grid images
233+
metadata_prompt = f'{normalized_prompt} -S{first_seed}'
234+
path = file_writer.save_image_and_prompt_to_png(
220235
grid_img, metadata_prompt, filename
221236
)
237+
results = [[path, seeds]]
222238

223-
last_seeds = [r[1] for r in results]
239+
last_seeds = list(seeds)
224240

225241
except AssertionError as e:
226242
print(e)

0 commit comments

Comments
 (0)