1
+ import time
2
+
1
3
from modules .api .models import StableDiffusionTxt2ImgProcessingAPI , StableDiffusionImg2ImgProcessingAPI
2
4
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
3
5
from modules .sd_samplers import all_samplers
4
6
from modules .extras import run_pnginfo
5
7
import modules .shared as shared
8
+ from modules import devices
6
9
import uvicorn
7
10
from fastapi import Body , APIRouter , HTTPException
8
11
from fastapi .responses import JSONResponse
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
25
28
parameters : Json
26
29
info : Json
27
30
31
class ProgressResponse(BaseModel):
    """Response body for GET /sdapi/v1/progress."""
    # fraction of the overall job completed; floored at 0.01 by the handler
    # to avoid a zero division, clamped to 1 before being returned
    progress: float
    # estimated seconds remaining, extrapolated from elapsed time / progress
    eta_relative: float
    # serialized shared.state (job name, step counters, timestamps) as JSON
    state: Json
35
+
36
+ # copy from wrap_gradio_gpu_call of webui.py
37
+ # because queue lock will be acquired in api handlers
38
+ # and time start needs to be set
39
+ # the function has been modified into two parts
40
+
41
def before_gpu_call():
    """Reset the shared generation state before running a job.

    Adapted from the first half of webui.py's wrap_gradio_gpu_call: the API
    handlers acquire the queue lock themselves, so only the state reset and
    start-time bookkeeping are reproduced here.
    """
    devices.torch_gc()

    state = shared.state
    state.sampling_step = 0
    state.job_count = -1  # sentinel: the job itself fills in the real count
    state.job_no = 0
    state.job_timestamp = state.get_job_timestamp()
    state.current_latent = None
    state.current_image = None
    state.current_image_sampling_step = 0
    state.skipped = False
    state.interrupted = False
    state.textinfo = None
    state.time_start = time.time()  # progressapi uses this for the ETA estimate
55
+
56
+
57
def after_gpu_call():
    """Clear job bookkeeping and free GPU memory once a job has finished.

    Second half of webui.py's wrap_gradio_gpu_call, split out for the API
    handlers (see before_gpu_call).
    """
    state = shared.state
    state.job = ""
    state.job_count = 0

    devices.torch_gc()
28
62
29
63
class Api :
30
64
def __init__ (self , app , queue_lock ):
@@ -33,6 +67,7 @@ def __init__(self, app, queue_lock):
33
67
self .queue_lock = queue_lock
34
68
self .app .add_api_route ("/sdapi/v1/txt2img" , self .text2imgapi , methods = ["POST" ])
35
69
self .app .add_api_route ("/sdapi/v1/img2img" , self .img2imgapi , methods = ["POST" ])
70
+ self .app .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ])
36
71
37
72
def __base64_to_image (self , base64_string ):
38
73
# if has a comma, deal with prefix
@@ -44,53 +79,55 @@ def __base64_to_image(self, base64_string):
44
79
45
80
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
    """Handle POST /sdapi/v1/txt2img: run txt2img and return base64 PNGs.

    Raises HTTPException(404) when the requested sampler name is unknown.
    """
    sampler_index = sampler_to_index(txt2imgreq.sampler_index)
    if sampler_index is None:
        raise HTTPException(status_code=404, detail="Sampler not found")

    # Build the processing object from the request, overriding the fields
    # the API must control (model, resolved sampler, no disk output).
    populate = txt2imgreq.copy(update={
        "sd_model": shared.sd_model,
        "sampler_index": sampler_index[0],
        "do_not_save_samples": True,
        "do_not_save_grid": True,
    })
    p = StableDiffusionProcessingTxt2Img(**vars(populate))

    # Reset progress state, run under the GPU queue lock, then clean up.
    before_gpu_call()
    with self.queue_lock:
        processed = process_images(p)
    after_gpu_call()

    def encode_png(img):
        # Serialize one PIL image to base64-encoded PNG bytes.
        buf = io.BytesIO()
        img.save(buf, format="png")
        return base64.b64encode(buf.getvalue())

    b64images = [encode_png(img) for img in processed.images]

    return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
70
-
71
-
107
+
108
+
72
109
73
110
def img2imgapi (self , img2imgreq : StableDiffusionImg2ImgProcessingAPI ):
74
111
sampler_index = sampler_to_index (img2imgreq .sampler_index )
75
-
112
+
76
113
if sampler_index is None :
77
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
114
+ raise HTTPException (status_code = 404 , detail = "Sampler not found" )
78
115
79
116
80
117
init_images = img2imgreq .init_images
81
118
if init_images is None :
82
- raise HTTPException (status_code = 404 , detail = "Init image not found" )
119
+ raise HTTPException (status_code = 404 , detail = "Init image not found" )
83
120
84
121
mask = img2imgreq .mask
85
122
if mask :
86
123
mask = self .__base64_to_image (mask )
87
124
88
-
125
+
89
126
populate = img2imgreq .copy (update = { # Override __init__ params
90
- "sd_model" : shared .sd_model ,
127
+ "sd_model" : shared .sd_model ,
91
128
"sampler_index" : sampler_index [0 ],
92
129
"do_not_save_samples" : True ,
93
- "do_not_save_grid" : True ,
130
+ "do_not_save_grid" : True ,
94
131
"mask" : mask
95
132
}
96
133
)
@@ -103,9 +140,11 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
103
140
104
141
p .init_images = imgs
105
142
# Override object param
143
+ before_gpu_call ()
106
144
with self .queue_lock :
107
145
processed = process_images (p )
108
-
146
+ after_gpu_call ()
147
+
109
148
b64images = []
110
149
for i in processed .images :
111
150
buffer = io .BytesIO ()
@@ -118,6 +157,28 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
118
157
119
158
return ImageToImageResponse (images = b64images , parameters = json .dumps (vars (img2imgreq )), info = processed .js ())
120
159
160
def progressapi(self):
    """Handle GET /sdapi/v1/progress: report job progress and a relative ETA.

    Copied from check_progress_call in ui.py. Returns a ProgressResponse with
    progress in [0, 1], the estimated seconds remaining, and the serialized
    shared state.
    """
    if shared.state.job_count == 0:
        # No job queued or running — nothing to report.
        return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())

    # Small floor so the ETA division below never sees zero.
    progress = 0.01

    if shared.state.job_count > 0:
        progress += shared.state.job_no / shared.state.job_count
        # FIX: the per-step term must stay inside the job_count > 0 guard.
        # before_gpu_call() sets job_count to -1 until the job configures it,
        # and an unguarded `1 / shared.state.job_count` in that window would
        # add a negative contribution, corrupting progress and the ETA.
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

    # Linear extrapolation: total time ≈ elapsed / fraction done.
    time_since_start = time.time() - shared.state.time_start
    eta = time_since_start / progress
    eta_relative = eta - time_since_start

    progress = min(progress, 1)

    return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
181
+
121
182
def extrasapi(self):
    """Placeholder for the extras endpoint; not implemented yet.

    Raises NotImplementedError unconditionally — no route is registered
    for it in the visible __init__, so this is a stub awaiting wiring.
    """
    raise NotImplementedError
123
184
0 commit comments