Skip to content

Commit 6c0dd9b

Browse files
CapableWeblstein
authored andcommitted
Add back old dream.py as legacy_api.py
This commit "reverts" the new API changes by extracting the old functionality into new files. The work is based on the commit `803a51d5adca7e6e28491fc414fd3937bee7cb79` PngWriter regained PromptFormatter as old server used that. `server_legacy.py` is the old server that `dream.py` used. Finally `legacy_api.py` is what `dream.py` used to be at the mentioned commit. One manually run test has been added in order to be able to test compatibility with the old API, currently just testing that the API endpoint works the same way + the image hash is the same as it used to be before.
1 parent ca6385e commit 6c0dd9b

File tree

4 files changed

+1017
-0
lines changed

4 files changed

+1017
-0
lines changed

ldm/invoke/pngwriter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,43 @@ def write_metadata(img_path:str, meta:dict):
6666
info = PngImagePlugin.PngInfo()
6767
info.add_text('sd-metadata', json.dumps(meta))
6868
im.save(img_path,'PNG',pnginfo=info)
69+
70+
class PromptFormatter:
71+
def __init__(self, t2i, opt):
72+
self.t2i = t2i
73+
self.opt = opt
74+
75+
# note: the t2i object should provide all these values.
76+
# there should be no need to or against opt values
77+
def normalize_prompt(self):
78+
"""Normalize the prompt and switches"""
79+
t2i = self.t2i
80+
opt = self.opt
81+
82+
switches = list()
83+
switches.append(f'"{opt.prompt}"')
84+
switches.append(f'-s{opt.steps or t2i.steps}')
85+
switches.append(f'-W{opt.width or t2i.width}')
86+
switches.append(f'-H{opt.height or t2i.height}')
87+
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
88+
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
89+
# to do: put model name into the t2i object
90+
# switches.append(f'--model{t2i.model_name}')
91+
if opt.seamless or t2i.seamless:
92+
switches.append(f'--seamless')
93+
if opt.init_img:
94+
switches.append(f'-I{opt.init_img}')
95+
if opt.fit:
96+
switches.append(f'--fit')
97+
if opt.strength and opt.init_img is not None:
98+
switches.append(f'-f{opt.strength or t2i.strength}')
99+
if opt.gfpgan_strength:
100+
switches.append(f'-G{opt.gfpgan_strength}')
101+
if opt.upscale:
102+
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
103+
if opt.variation_amount > 0:
104+
switches.append(f'-v{opt.variation_amount}')
105+
if opt.with_variations:
106+
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
107+
switches.append(f'-V{formatted_variations}')
108+
return ' '.join(switches)

ldm/invoke/server_legacy.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import argparse
2+
import json
3+
import base64
4+
import mimetypes
5+
import os
6+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
7+
from ldm.invoke.pngwriter import PngWriter, PromptFormatter
8+
from threading import Event
9+
10+
def build_opt(post_data, seed, gfpgan_model_exists):
11+
opt = argparse.Namespace()
12+
setattr(opt, 'prompt', post_data['prompt'])
13+
setattr(opt, 'init_img', post_data['initimg'])
14+
setattr(opt, 'strength', float(post_data['strength']))
15+
setattr(opt, 'iterations', int(post_data['iterations']))
16+
setattr(opt, 'steps', int(post_data['steps']))
17+
setattr(opt, 'width', int(post_data['width']))
18+
setattr(opt, 'height', int(post_data['height']))
19+
setattr(opt, 'seamless', 'seamless' in post_data)
20+
setattr(opt, 'fit', 'fit' in post_data)
21+
setattr(opt, 'mask', 'mask' in post_data)
22+
setattr(opt, 'invert_mask', 'invert_mask' in post_data)
23+
setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
24+
setattr(opt, 'sampler_name', post_data['sampler_name'])
25+
setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
26+
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
27+
setattr(opt, 'progress_images', 'progress_images' in post_data)
28+
setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
29+
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
30+
setattr(opt, 'with_variations', [])
31+
32+
broken = False
33+
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
34+
for part in post_data['with_variations'].split(','):
35+
seed_and_weight = part.split(':')
36+
if len(seed_and_weight) != 2:
37+
print(f'could not parse with_variation part "{part}"')
38+
broken = True
39+
break
40+
try:
41+
seed = int(seed_and_weight[0])
42+
weight = float(seed_and_weight[1])
43+
except ValueError:
44+
print(f'could not parse with_variation part "{part}"')
45+
broken = True
46+
break
47+
opt.with_variations.append([seed, weight])
48+
49+
if broken:
50+
raise CanceledException
51+
52+
if len(opt.with_variations) == 0:
53+
opt.with_variations = None
54+
55+
return opt
56+
57+
class CanceledException(Exception):
58+
pass
59+
60+
class DreamServer(BaseHTTPRequestHandler):
61+
model = None
62+
outdir = None
63+
canceled = Event()
64+
65+
def do_GET(self):
66+
if self.path == "/":
67+
self.send_response(200)
68+
self.send_header("Content-type", "text/html")
69+
self.end_headers()
70+
with open("./static/dream_web/index.html", "rb") as content:
71+
self.wfile.write(content.read())
72+
elif self.path == "/config.js":
73+
# unfortunately this import can't be at the top level, since that would cause a circular import
74+
from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
75+
self.send_response(200)
76+
self.send_header("Content-type", "application/javascript")
77+
self.end_headers()
78+
config = {
79+
'gfpgan_model_exists': gfpgan_model_exists
80+
}
81+
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
82+
elif self.path == "/run_log.json":
83+
self.send_response(200)
84+
self.send_header("Content-type", "application/json")
85+
self.end_headers()
86+
output = []
87+
88+
log_file = os.path.join(self.outdir, "dream_web_log.txt")
89+
if os.path.exists(log_file):
90+
with open(log_file, "r") as log:
91+
for line in log:
92+
url, config = line.split(": {", maxsplit=1)
93+
config = json.loads("{" + config)
94+
config["url"] = url.lstrip(".")
95+
if os.path.exists(url):
96+
output.append(config)
97+
98+
self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8"))
99+
elif self.path == "/cancel":
100+
self.canceled.set()
101+
self.send_response(200)
102+
self.send_header("Content-type", "application/json")
103+
self.end_headers()
104+
self.wfile.write(bytes('{}', 'utf8'))
105+
else:
106+
path = "." + self.path
107+
cwd = os.path.realpath(os.getcwd())
108+
is_in_cwd = os.path.commonprefix((os.path.realpath(path), cwd)) == cwd
109+
if not (is_in_cwd and os.path.exists(path)):
110+
self.send_response(404)
111+
return
112+
mime_type = mimetypes.guess_type(path)[0]
113+
if mime_type is not None:
114+
self.send_response(200)
115+
self.send_header("Content-type", mime_type)
116+
self.end_headers()
117+
with open("." + self.path, "rb") as content:
118+
self.wfile.write(content.read())
119+
else:
120+
self.send_response(404)
121+
122+
def do_POST(self):
123+
self.send_response(200)
124+
self.send_header("Content-type", "application/json")
125+
self.end_headers()
126+
127+
# unfortunately this import can't be at the top level, since that would cause a circular import
128+
# TODO temporarily commented out, import fails for some reason
129+
# from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
130+
gfpgan_model_exists = False
131+
132+
content_length = int(self.headers['Content-Length'])
133+
post_data = json.loads(self.rfile.read(content_length))
134+
opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
135+
136+
self.canceled.clear()
137+
print(f">> Request to generate with prompt: {opt.prompt}")
138+
# In order to handle upscaled images, the PngWriter needs to maintain state
139+
# across images generated by each call to prompt2img(), so we define it in
140+
# the outer scope of image_done()
141+
config = post_data.copy() # Shallow copy
142+
config['initimg'] = config.pop('initimg_name', '')
143+
144+
images_generated = 0 # helps keep track of when upscaling is started
145+
images_upscaled = 0 # helps keep track of when upscaling is completed
146+
pngwriter = PngWriter(self.outdir)
147+
148+
prefix = pngwriter.unique_prefix()
149+
# if upscaling is requested, then this will be called twice, once when
150+
# the images are first generated, and then again when after upscaling
151+
# is complete. The upscaling replaces the original file, so the second
152+
# entry should not be inserted into the image list.
153+
def image_done(image, seed, upscaled=False, first_seed=-1, use_prefix=None):
154+
print(f'First seed: {first_seed}')
155+
name = f'{prefix}.{seed}.png'
156+
iter_opt = argparse.Namespace(**vars(opt)) # copy
157+
if opt.variation_amount > 0:
158+
this_variation = [[seed, opt.variation_amount]]
159+
if opt.with_variations is None:
160+
iter_opt.with_variations = this_variation
161+
else:
162+
iter_opt.with_variations = opt.with_variations + this_variation
163+
iter_opt.variation_amount = 0
164+
elif opt.with_variations is None:
165+
iter_opt.seed = seed
166+
normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
167+
path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
168+
169+
if int(config['seed']) == -1:
170+
config['seed'] = seed
171+
# Append post_data to log, but only once!
172+
if not upscaled:
173+
with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
174+
log.write(f"{path}: {json.dumps(config)}\n")
175+
176+
self.wfile.write(bytes(json.dumps(
177+
{'event': 'result', 'url': path, 'seed': seed, 'config': config}
178+
) + '\n',"utf-8"))
179+
180+
# control state of the "postprocessing..." message
181+
upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
182+
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
183+
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
184+
if upscaled:
185+
images_upscaled += 1
186+
else:
187+
images_generated += 1
188+
if upscaling_requested:
189+
action = None
190+
if images_generated >= opt.iterations:
191+
if images_upscaled < opt.iterations:
192+
action = 'upscaling-started'
193+
else:
194+
action = 'upscaling-done'
195+
if action:
196+
x = images_upscaled + 1
197+
self.wfile.write(bytes(json.dumps(
198+
{'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
199+
) + '\n',"utf-8"))
200+
201+
step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
202+
step_index = 1
203+
def image_progress(sample, step):
204+
if self.canceled.is_set():
205+
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
206+
raise CanceledException
207+
path = None
208+
# since rendering images is moderately expensive, only render every 5th image
209+
# and don't bother with the last one, since it'll render anyway
210+
nonlocal step_index
211+
if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
212+
image = self.model.sample_to_image(sample)
213+
name = f'{prefix}.{opt.seed}.{step_index}.png'
214+
metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
215+
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
216+
step_index += 1
217+
self.wfile.write(bytes(json.dumps(
218+
{'event': 'step', 'step': step + 1, 'url': path}
219+
) + '\n',"utf-8"))
220+
221+
try:
222+
if opt.init_img is None:
223+
# Run txt2img
224+
self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
225+
else:
226+
# Decode initimg as base64 to temp file
227+
with open("./img2img-tmp.png", "wb") as f:
228+
initimg = opt.init_img.split(",")[1] # Ignore mime type
229+
f.write(base64.b64decode(initimg))
230+
opt1 = argparse.Namespace(**vars(opt))
231+
opt1.init_img = "./img2img-tmp.png"
232+
233+
try:
234+
# Run img2img
235+
self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
236+
finally:
237+
# Remove the temp file
238+
os.remove("./img2img-tmp.png")
239+
except CanceledException:
240+
print(f"Canceled.")
241+
return
242+
243+
244+
class ThreadingDreamServer(ThreadingHTTPServer):
245+
def __init__(self, server_address):
246+
super(ThreadingDreamServer, self).__init__(server_address, DreamServer)

0 commit comments

Comments
 (0)