Skip to content

Commit fddb488

Browse files
committed
prototype progress api
1 parent 99d728b commit fddb488

File tree

2 files changed

+88
-14
lines changed

2 files changed

+88
-14
lines changed

modules/api/api.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import time
2+
13
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
24
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
35
from modules.sd_samplers import all_samplers
46
from modules.extras import run_pnginfo
57
import modules.shared as shared
8+
from modules import devices
69
import uvicorn
710
from fastapi import Body, APIRouter, HTTPException
811
from fastapi.responses import JSONResponse
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
2528
parameters: Json
2629
info: Json
2730

31+
class ProgressResponse(BaseModel):
32+
progress: float
33+
eta_relative: float
34+
state: Json
35+
36+
# copy from wrap_gradio_gpu_call of webui.py
37+
# because queue lock will be acquired in api handlers
38+
# and time start needs to be set
39+
# the function has been modified into two parts
40+
41+
def before_gpu_call():
42+
devices.torch_gc()
43+
44+
shared.state.sampling_step = 0
45+
shared.state.job_count = -1
46+
shared.state.job_no = 0
47+
shared.state.job_timestamp = shared.state.get_job_timestamp()
48+
shared.state.current_latent = None
49+
shared.state.current_image = None
50+
shared.state.current_image_sampling_step = 0
51+
shared.state.skipped = False
52+
shared.state.interrupted = False
53+
shared.state.textinfo = None
54+
shared.state.time_start = time.time()
55+
56+
57+
def after_gpu_call():
58+
shared.state.job = ""
59+
shared.state.job_count = 0
60+
61+
devices.torch_gc()
2862

2963
class Api:
3064
def __init__(self, app, queue_lock):
@@ -33,6 +67,7 @@ def __init__(self, app, queue_lock):
3367
self.queue_lock = queue_lock
3468
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
3569
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
70+
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
3671

3772
def __base64_to_image(self, base64_string):
3873
# if has a comma, deal with prefix
@@ -44,53 +79,55 @@ def __base64_to_image(self, base64_string):
4479

4580
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
4681
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
47-
82+
4883
if sampler_index is None:
49-
raise HTTPException(status_code=404, detail="Sampler not found")
50-
84+
raise HTTPException(status_code=404, detail="Sampler not found")
85+
5186
populate = txt2imgreq.copy(update={ # Override __init__ params
52-
"sd_model": shared.sd_model,
87+
"sd_model": shared.sd_model,
5388
"sampler_index": sampler_index[0],
5489
"do_not_save_samples": True,
5590
"do_not_save_grid": True
5691
}
5792
)
5893
p = StableDiffusionProcessingTxt2Img(**vars(populate))
5994
# Override object param
95+
before_gpu_call()
6096
with self.queue_lock:
6197
processed = process_images(p)
62-
98+
after_gpu_call()
99+
63100
b64images = []
64101
for i in processed.images:
65102
buffer = io.BytesIO()
66103
i.save(buffer, format="png")
67104
b64images.append(base64.b64encode(buffer.getvalue()))
68105

69106
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
70-
71-
107+
108+
72109

73110
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
74111
sampler_index = sampler_to_index(img2imgreq.sampler_index)
75-
112+
76113
if sampler_index is None:
77-
raise HTTPException(status_code=404, detail="Sampler not found")
114+
raise HTTPException(status_code=404, detail="Sampler not found")
78115

79116

80117
init_images = img2imgreq.init_images
81118
if init_images is None:
82-
raise HTTPException(status_code=404, detail="Init image not found")
119+
raise HTTPException(status_code=404, detail="Init image not found")
83120

84121
mask = img2imgreq.mask
85122
if mask:
86123
mask = self.__base64_to_image(mask)
87124

88-
125+
89126
populate = img2imgreq.copy(update={ # Override __init__ params
90-
"sd_model": shared.sd_model,
127+
"sd_model": shared.sd_model,
91128
"sampler_index": sampler_index[0],
92129
"do_not_save_samples": True,
93-
"do_not_save_grid": True,
130+
"do_not_save_grid": True,
94131
"mask": mask
95132
}
96133
)
@@ -103,9 +140,11 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
103140

104141
p.init_images = imgs
105142
# Override object param
143+
before_gpu_call()
106144
with self.queue_lock:
107145
processed = process_images(p)
108-
146+
after_gpu_call()
147+
109148
b64images = []
110149
for i in processed.images:
111150
buffer = io.BytesIO()
@@ -118,6 +157,28 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
118157

119158
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
120159

160+
def progressapi(self):
161+
# copy from check_progress_call of ui.py
162+
163+
if shared.state.job_count == 0:
164+
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
165+
166+
# avoid dividing zero
167+
progress = 0.01
168+
169+
if shared.state.job_count > 0:
170+
progress += shared.state.job_no / shared.state.job_count
171+
if shared.state.sampling_steps > 0:
172+
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
173+
174+
time_since_start = time.time() - shared.state.time_start
175+
eta = (time_since_start/progress)
176+
eta_relative = eta-time_since_start
177+
178+
progress = min(progress, 1)
179+
180+
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
181+
121182
def extrasapi(self):
122183
raise NotImplementedError
123184

modules/shared.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,19 @@ def nextjob(self):
146146
def get_job_timestamp(self):
147147
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
148148

149+
def js(self):
150+
obj = {
151+
"skipped": self.skipped,
152+
"interrupted": self.skipped,
153+
"job": self.job,
154+
"job_count": self.job_count,
155+
"job_no": self.job_no,
156+
"sampling_step": self.sampling_step,
157+
"sampling_steps": self.sampling_steps,
158+
}
159+
160+
return json.dumps(obj)
161+
149162

150163
state = State()
151164

0 commit comments

Comments
 (0)