1- import numpy as np
21from pytorch_lightning import seed_everything
32
43from scripts .demo .streamlit_helpers import *
5- from scripts .util .detection .nsfw_and_watermark_dectection import DeepFloydDataFiltering
6- from sgm .inference .helpers import (
7- do_img2img ,
8- do_sample ,
9- get_unique_embedder_keys_from_conditioner ,
10- perform_save_locally ,
11- )
124
135SAVE_PATH = "outputs/demo/txt2img/"
146
4234}
4335
4436VERSION2SPECS = {
45- "SD-XL base" : {
37+ "SDXL-base-1.0" : {
38+ "H" : 1024 ,
39+ "W" : 1024 ,
40+ "C" : 4 ,
41+ "f" : 8 ,
42+ "is_legacy" : False ,
43+ "config" : "configs/inference/sd_xl_base.yaml" ,
44+ "ckpt" : "checkpoints/sd_xl_base_1.0.safetensors" ,
45+ },
46+ "SDXL-base-0.9" : {
4647 "H" : 1024 ,
4748 "W" : 1024 ,
4849 "C" : 4 ,
4950 "f" : 8 ,
5051 "is_legacy" : False ,
5152 "config" : "configs/inference/sd_xl_base.yaml" ,
5253 "ckpt" : "checkpoints/sd_xl_base_0.9.safetensors" ,
53- "is_guided" : True ,
5454 },
55- "sd -2.1" : {
55+ "SD -2.1" : {
5656 "H" : 512 ,
5757 "W" : 512 ,
5858 "C" : 4 ,
5959 "f" : 8 ,
6060 "is_legacy" : True ,
6161 "config" : "configs/inference/sd_2_1.yaml" ,
6262 "ckpt" : "checkpoints/v2-1_512-ema-pruned.safetensors" ,
63- "is_guided" : True ,
6463 },
65- "sd -2.1-768" : {
64+ "SD -2.1-768" : {
6665 "H" : 768 ,
6766 "W" : 768 ,
6867 "C" : 4 ,
7170 "config" : "configs/inference/sd_2_1_768.yaml" ,
7271 "ckpt" : "checkpoints/v2-1_768-ema-pruned.safetensors" ,
7372 },
74- "SDXL-Refiner " : {
73+ "SDXL-refiner-0.9 " : {
7574 "H" : 1024 ,
7675 "W" : 1024 ,
7776 "C" : 4 ,
7877 "f" : 8 ,
7978 "is_legacy" : True ,
8079 "config" : "configs/inference/sd_xl_refiner.yaml" ,
8180 "ckpt" : "checkpoints/sd_xl_refiner_0.9.safetensors" ,
82- "is_guided" : True ,
81+ },
82+ "SDXL-refiner-1.0" : {
83+ "H" : 1024 ,
84+ "W" : 1024 ,
85+ "C" : 4 ,
86+ "f" : 8 ,
87+ "is_legacy" : True ,
88+ "config" : "configs/inference/sd_xl_refiner.yaml" ,
89+ "ckpt" : "checkpoints/sd_xl_refiner_1.0.safetensors" ,
8390 },
8491}
8592
@@ -103,18 +110,19 @@ def load_img(display=True, key=None, device="cuda"):
103110
104111
105112def run_txt2img (
106- state , version , version_dict , is_legacy = False , return_latents = False , filter = None
113+ state ,
114+ version ,
115+ version_dict ,
116+ is_legacy = False ,
117+ return_latents = False ,
118+ filter = None ,
119+ stage2strength = None ,
107120):
108- if version == "SD-XL base" :
109- ratio = st .sidebar .selectbox ("Ratio:" , list (SD_XL_BASE_RATIOS .keys ()), 10 )
110- W , H = SD_XL_BASE_RATIOS [ratio ]
121+ if version .startswith ("SDXL-base" ):
122+ W , H = st .selectbox ("Resolution:" , list (SD_XL_BASE_RATIOS .values ()), 10 )
111123 else :
112- H = st .sidebar .number_input (
113- "H" , value = version_dict ["H" ], min_value = 64 , max_value = 2048
114- )
115- W = st .sidebar .number_input (
116- "W" , value = version_dict ["W" ], min_value = 64 , max_value = 2048
117- )
124+ H = st .number_input ("H" , value = version_dict ["H" ], min_value = 64 , max_value = 2048 )
125+ W = st .number_input ("W" , value = version_dict ["W" ], min_value = 64 , max_value = 2048 )
118126 C = version_dict ["C" ]
119127 F = version_dict ["f" ]
120128
@@ -130,16 +138,11 @@ def run_txt2img(
130138 prompt = prompt ,
131139 negative_prompt = negative_prompt ,
132140 )
133- num_rows , num_cols , sampler = init_sampling (
134- use_identity_guider = not version_dict ["is_guided" ]
135- )
136-
141+ sampler , num_rows , num_cols = init_sampling (stage2strength = stage2strength )
137142 num_samples = num_rows * num_cols
138143
139144 if st .button ("Sample" ):
140145 st .write (f"**Model I:** { version } " )
141- outputs = st .empty ()
142- st .text ("Sampling" )
143146 out = do_sample (
144147 state ["model" ],
145148 sampler ,
@@ -153,13 +156,16 @@ def run_txt2img(
153156 return_latents = return_latents ,
154157 filter = filter ,
155158 )
156- show_samples (out , outputs )
157-
158159 return out
159160
160161
161162def run_img2img (
162- state , version_dict , is_legacy = False , return_latents = False , filter = None
163+ state ,
164+ version_dict ,
165+ is_legacy = False ,
166+ return_latents = False ,
167+ filter = None ,
168+ stage2strength = None ,
163169):
164170 img = load_img ()
165171 if img is None :
@@ -175,19 +181,19 @@ def run_img2img(
175181 value_dict = init_embedder_options (
176182 get_unique_embedder_keys_from_conditioner (state ["model" ].conditioner ),
177183 init_dict ,
184+ prompt = prompt ,
185+ negative_prompt = negative_prompt ,
178186 )
179187 strength = st .number_input (
180- "**Img2Img Strength**" , value = 0.5 , min_value = 0.0 , max_value = 1.0
188+ "**Img2Img Strength**" , value = 0.75 , min_value = 0.0 , max_value = 1.0
181189 )
182- num_rows , num_cols , sampler = init_sampling (
190+ sampler , num_rows , num_cols = init_sampling (
183191 img2img_strength = strength ,
184- use_identity_guider = not version_dict [ "is_guided" ] ,
192+ stage2strength = stage2strength ,
185193 )
186194 num_samples = num_rows * num_cols
187195
188196 if st .button ("Sample" ):
189- outputs = st .empty ()
190- st .text ("Sampling" )
191197 out = do_img2img (
192198 repeat (img , "1 ... -> n ..." , n = num_samples ),
193199 state ["model" ],
@@ -198,7 +204,6 @@ def run_img2img(
198204 return_latents = return_latents ,
199205 filter = filter ,
200206 )
201- show_samples (out , outputs )
202207 return out
203208
204209
@@ -210,6 +215,7 @@ def apply_refiner(
210215 prompt ,
211216 negative_prompt ,
212217 filter = None ,
218+ finish_denoising = False ,
213219):
214220 init_dict = {
215221 "orig_width" : input .shape [3 ] * 8 ,
@@ -237,6 +243,7 @@ def apply_refiner(
237243 num_samples ,
238244 skip_encode = True ,
239245 filter = filter ,
246+ add_noise = not finish_denoising ,
240247 )
241248
242249 return samples
@@ -249,20 +256,22 @@ def apply_refiner(
249256 mode = st .radio ("Mode" , ("txt2img" , "img2img" ), 0 )
250257 st .write ("__________________________" )
251258
252- if version == "SD-XL base" :
253- add_pipeline = st .checkbox ("Load SDXL-Refiner?" , False )
259+ set_lowvram_mode (st .checkbox ("Low vram mode" , True ))
260+
261+ if version .startswith ("SDXL-base" ):
262+ add_pipeline = st .checkbox ("Load SDXL-refiner?" , False )
254263 st .write ("__________________________" )
255264 else :
256265 add_pipeline = False
257266
258- filter = DeepFloydDataFiltering (verbose = False )
259-
260267 seed = st .sidebar .number_input ("seed" , value = 42 , min_value = 0 , max_value = int (1e9 ))
261268 seed_everything (seed )
262269
263270 save_locally , save_path = init_save_locally (os .path .join (SAVE_PATH , version ))
264271
265- state = init_st (version_dict )
272+ state = init_st (version_dict , load_filter = True )
273+ if state ["msg" ]:
274+ st .info (state ["msg" ])
266275 model = state ["model" ]
267276
268277 is_legacy = version_dict ["is_legacy" ]
@@ -276,29 +285,34 @@ def apply_refiner(
276285 else :
277286 negative_prompt = "" # which is unused
278287
288+ stage2strength = None
289+ finish_denoising = False
290+
279291 if add_pipeline :
280292 st .write ("__________________________" )
281-
282- version2 = "SDXL-Refiner"
293+ version2 = st .selectbox ("Refiner:" , ["SDXL-refiner-1.0" , "SDXL-refiner-0.9" ])
283294 st .warning (
284295 f"Running with { version2 } as the second stage model. Make sure to provide (V)RAM :) "
285296 )
286297 st .write ("**Refiner Options:**" )
287298
288299 version_dict2 = VERSION2SPECS [version2 ]
289- state2 = init_st (version_dict2 )
300+ state2 = init_st (version_dict2 , load_filter = False )
301+ st .info (state2 ["msg" ])
290302
291303 stage2strength = st .number_input (
292- "**Refinement strength**" , value = 0.3 , min_value = 0.0 , max_value = 1.0
304+ "**Refinement strength**" , value = 0.15 , min_value = 0.0 , max_value = 1.0
293305 )
294306
295- sampler2 = init_sampling (
307+ sampler2 , * _ = init_sampling (
296308 key = 2 ,
297309 img2img_strength = stage2strength ,
298- use_identity_guider = not version_dict2 ["is_guided" ],
299- get_num_samples = False ,
310+ specify_num_samples = False ,
300311 )
301312 st .write ("__________________________" )
313+ finish_denoising = st .checkbox ("Finish denoising with refiner." , True )
314+ if not finish_denoising :
315+ stage2strength = None
302316
303317 if mode == "txt2img" :
304318 out = run_txt2img (
@@ -307,15 +321,17 @@ def apply_refiner(
307321 version_dict ,
308322 is_legacy = is_legacy ,
309323 return_latents = add_pipeline ,
310- filter = filter ,
324+ filter = state .get ("filter" ),
325+ stage2strength = stage2strength ,
311326 )
312327 elif mode == "img2img" :
313328 out = run_img2img (
314329 state ,
315330 version_dict ,
316331 is_legacy = is_legacy ,
317332 return_latents = add_pipeline ,
318- filter = filter ,
333+ filter = state .get ("filter" ),
334+ stage2strength = stage2strength ,
319335 )
320336 else :
321337 raise ValueError (f"unknown mode { mode } " )
@@ -326,7 +342,6 @@ def apply_refiner(
326342 samples_z = None
327343
328344 if add_pipeline and samples_z is not None :
329- outputs = st .empty ()
330345 st .write ("**Running Refinement Stage**" )
331346 samples = apply_refiner (
332347 samples_z ,
@@ -335,9 +350,9 @@ def apply_refiner(
335350 samples_z .shape [0 ],
336351 prompt = prompt ,
337352 negative_prompt = negative_prompt if is_legacy else "" ,
338- filter = filter ,
353+ filter = state .get ("filter" ),
354+ finish_denoising = finish_denoising ,
339355 )
340- show_samples (samples , outputs )
341356
342357 if save_locally and samples is not None :
343358 perform_save_locally (save_path , samples )
0 commit comments