Skip to content

Commit 9bb6b65

Browse files
committed
add postprocess call for scripts
1 parent 35c45df commit 9bb6b65

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

modules/processing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def infotext(iteration=0, position_in_batch=0):
478478
model_hijack.embedding_db.load_textual_inversion_embeddings()
479479

480480
if p.scripts is not None:
481-
p.scripts.run_alwayson_scripts(p)
481+
p.scripts.process(p)
482482

483483
infotexts = []
484484
output_images = []
@@ -501,7 +501,7 @@ def infotext(iteration=0, position_in_batch=0):
501501
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
502502
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
503503

504-
if (len(prompts) == 0):
504+
if len(prompts) == 0:
505505
break
506506

507507
with devices.autocast():
@@ -590,7 +590,13 @@ def infotext(iteration=0, position_in_batch=0):
590590
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
591591

592592
devices.torch_gc()
593-
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
593+
594+
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
595+
596+
if p.scripts is not None:
597+
p.scripts.postprocess(p, res)
598+
599+
return res
594600

595601

596602
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

modules/scripts.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,16 @@ def run(self, p, *args):
6464
def process(self, p, *args):
6565
"""
6666
This function is called before processing begins for AlwaysVisible scripts.
67-
scripts. You can modify the processing object (p) here, inject hooks, etc.
67+
You can modify the processing object (p) here, inject hooks, etc.
68+
args contains all values returned by components from ui()
69+
"""
70+
71+
pass
72+
73+
def postprocess(self, p, processed, *args):
74+
"""
75+
This function is called after processing ends for AlwaysVisible scripts.
76+
args contains all values returned by components from ui()
6877
"""
6978

7079
pass
@@ -289,13 +298,22 @@ def run(self, p: StableDiffusionProcessing, *args):
289298

290299
return processed
291300

292-
def run_alwayson_scripts(self, p):
301+
def process(self, p):
293302
for script in self.alwayson_scripts:
294303
try:
295304
script_args = p.script_args[script.args_from:script.args_to]
296305
script.process(p, *script_args)
297306
except Exception:
298-
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
307+
print(f"Error running process: {script.filename}", file=sys.stderr)
308+
print(traceback.format_exc(), file=sys.stderr)
309+
310+
def postprocess(self, p, processed):
311+
for script in self.alwayson_scripts:
312+
try:
313+
script_args = p.script_args[script.args_from:script.args_to]
314+
script.postprocess(p, processed, *script_args)
315+
except Exception:
316+
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
299317
print(traceback.format_exc(), file=sys.stderr)
300318

301319
def reload_sources(self, cache):

0 commit comments

Comments
 (0)