1
+ import time
1
2
import uvicorn
2
3
from gradio .processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
3
- from fastapi import APIRouter , HTTPException
4
+ from fastapi import APIRouter , Depends , HTTPException
4
5
import modules .shared as shared
6
+ from modules import devices
5
7
from modules .api .models import *
6
8
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
7
9
from modules .sd_samplers import all_samplers
8
10
from modules .extras import run_extras , run_pnginfo
9
11
12
+ # copy from wrap_gradio_gpu_call of webui.py
13
+ # because queue lock will be acquired in api handlers
14
+ # and time start needs to be set
15
+ # the function has been modified into two parts
16
+
17
+ def before_gpu_call ():
18
+ devices .torch_gc ()
19
+
20
+ shared .state .sampling_step = 0
21
+ shared .state .job_count = - 1
22
+ shared .state .job_no = 0
23
+ shared .state .job_timestamp = shared .state .get_job_timestamp ()
24
+ shared .state .current_latent = None
25
+ shared .state .current_image = None
26
+ shared .state .current_image_sampling_step = 0
27
+ shared .state .skipped = False
28
+ shared .state .interrupted = False
29
+ shared .state .textinfo = None
30
+ shared .state .time_start = time .time ()
31
+
32
+ def after_gpu_call ():
33
+ shared .state .job = ""
34
+ shared .state .job_count = 0
35
+
36
+ devices .torch_gc ()
37
+
10
38
def upscaler_to_index (name : str ):
11
39
try :
12
40
return [x .name .lower () for x in shared .sd_upscalers ].index (name .lower ())
@@ -33,50 +61,53 @@ def __init__(self, app, queue_lock):
33
61
self .app .add_api_route ("/sdapi/v1/extra-single-image" , self .extras_single_image_api , methods = ["POST" ], response_model = ExtrasSingleImageResponse )
34
62
self .app .add_api_route ("/sdapi/v1/extra-batch-images" , self .extras_batch_images_api , methods = ["POST" ], response_model = ExtrasBatchImagesResponse )
35
63
self .app .add_api_route ("/sdapi/v1/png-info" , self .pnginfoapi , methods = ["POST" ], response_model = PNGInfoResponse )
64
+ self .app .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = ProgressResponse )
36
65
37
66
def text2imgapi (self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ):
38
67
sampler_index = sampler_to_index (txt2imgreq .sampler_index )
39
-
68
+
40
69
if sampler_index is None :
41
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
42
-
70
+ raise HTTPException (status_code = 404 , detail = "Sampler not found" )
71
+
43
72
populate = txt2imgreq .copy (update = { # Override __init__ params
44
- "sd_model" : shared .sd_model ,
73
+ "sd_model" : shared .sd_model ,
45
74
"sampler_index" : sampler_index [0 ],
46
75
"do_not_save_samples" : True ,
47
76
"do_not_save_grid" : True
48
77
}
49
78
)
50
79
p = StableDiffusionProcessingTxt2Img (** vars (populate ))
51
80
# Override object param
81
+ before_gpu_call ()
52
82
with self .queue_lock :
53
83
processed = process_images (p )
54
-
84
+ after_gpu_call ()
85
+
55
86
b64images = list (map (encode_pil_to_base64 , processed .images ))
56
-
87
+
57
88
return TextToImageResponse (images = b64images , parameters = vars (txt2imgreq ), info = processed .js ())
58
89
59
90
def img2imgapi (self , img2imgreq : StableDiffusionImg2ImgProcessingAPI ):
60
91
sampler_index = sampler_to_index (img2imgreq .sampler_index )
61
-
92
+
62
93
if sampler_index is None :
63
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
94
+ raise HTTPException (status_code = 404 , detail = "Sampler not found" )
64
95
65
96
66
97
init_images = img2imgreq .init_images
67
98
if init_images is None :
68
- raise HTTPException (status_code = 404 , detail = "Init image not found" )
99
+ raise HTTPException (status_code = 404 , detail = "Init image not found" )
69
100
70
101
mask = img2imgreq .mask
71
102
if mask :
72
103
mask = decode_base64_to_image (mask )
73
104
74
-
105
+
75
106
populate = img2imgreq .copy (update = { # Override __init__ params
76
- "sd_model" : shared .sd_model ,
107
+ "sd_model" : shared .sd_model ,
77
108
"sampler_index" : sampler_index [0 ],
78
109
"do_not_save_samples" : True ,
79
- "do_not_save_grid" : True ,
110
+ "do_not_save_grid" : True ,
80
111
"mask" : mask
81
112
}
82
113
)
@@ -89,15 +120,17 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
89
120
90
121
p .init_images = imgs
91
122
# Override object param
123
+ before_gpu_call ()
92
124
with self .queue_lock :
93
125
processed = process_images (p )
94
-
126
+ after_gpu_call ()
127
+
95
128
b64images = list (map (encode_pil_to_base64 , processed .images ))
96
129
97
130
if (not img2imgreq .include_init_images ):
98
131
img2imgreq .init_images = None
99
132
img2imgreq .mask = None
100
-
133
+
101
134
return ImageToImageResponse (images = b64images , parameters = vars (img2imgreq ), info = processed .js ())
102
135
103
136
def extras_single_image_api (self , req : ExtrasSingleImageRequest ):
@@ -125,7 +158,7 @@ def prepareFiles(file):
125
158
result = run_extras (extras_mode = 1 , image = "" , input_dir = "" , output_dir = "" , ** reqDict )
126
159
127
160
return ExtrasBatchImagesResponse (images = list (map (encode_pil_to_base64 , result [0 ])), html_info = result [1 ])
128
-
161
+
129
162
def pnginfoapi (self , req : PNGInfoRequest ):
130
163
if (not req .image .strip ()):
131
164
return PNGInfoResponse (info = "" )
@@ -134,6 +167,32 @@ def pnginfoapi(self, req: PNGInfoRequest):
134
167
135
168
return PNGInfoResponse (info = result [1 ])
136
169
170
+ def progressapi (self , req : ProgressRequest = Depends ()):
171
+ # copy from check_progress_call of ui.py
172
+
173
+ if shared .state .job_count == 0 :
174
+ return ProgressResponse (progress = 0 , eta_relative = 0 , state = shared .state .dict ())
175
+
176
+ # avoid dividing zero
177
+ progress = 0.01
178
+
179
+ if shared .state .job_count > 0 :
180
+ progress += shared .state .job_no / shared .state .job_count
181
+ if shared .state .sampling_steps > 0 :
182
+ progress += 1 / shared .state .job_count * shared .state .sampling_step / shared .state .sampling_steps
183
+
184
+ time_since_start = time .time () - shared .state .time_start
185
+ eta = (time_since_start / progress )
186
+ eta_relative = eta - time_since_start
187
+
188
+ progress = min (progress , 1 )
189
+
190
+ current_image = None
191
+ if shared .state .current_image and not req .skip_current_image :
192
+ current_image = encode_pil_to_base64 (shared .state .current_image )
193
+
194
+ return ProgressResponse (progress = progress , eta_relative = eta_relative , state = shared .state .dict (), current_image = current_image )
195
+
137
196
def launch (self , server_name , port ):
138
197
self .app .include_router (self .router )
139
198
uvicorn .run (self .app , host = server_name , port = port )
0 commit comments