from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo

- # copy from wrap_gradio_gpu_call of webui.py
- # because queue lock will be acquired in api handlers
- # and time start needs to be set
- # the function has been modified into two parts
-
- def before_gpu_call():
-     devices.torch_gc()
-
-     shared.state.sampling_step = 0
-     shared.state.job_count = -1
-     shared.state.job_no = 0
-     shared.state.job_timestamp = shared.state.get_job_timestamp()
-     shared.state.current_latent = None
-     shared.state.current_image = None
-     shared.state.current_image_sampling_step = 0
-     shared.state.skipped = False
-     shared.state.interrupted = False
-     shared.state.textinfo = None
-     shared.state.time_start = time.time()
-
- def after_gpu_call():
-     shared.state.job = ""
-     shared.state.job_count = 0
-
-     devices.torch_gc()

def upscaler_to_index(name: str):
    try:
        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
    except:
        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")

+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)

+
def setUpscalers(req: dict):
    reqDict = vars(req)
    reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
    reqDict.pop('upscaler_2')
    return reqDict

+
class Api:
    def __init__(self, app, queue_lock):
        self.router = APIRouter()
@@ -78,10 +56,13 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        )
        p = StableDiffusionProcessingTxt2Img(**vars(populate))
        # Override object param
-         before_gpu_call()
+
+         shared.state.begin()
+
        with self.queue_lock:
            processed = process_images(p)
-         after_gpu_call()
+
+         shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))
@@ -119,11 +100,13 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
            imgs = [img] * p.batch_size

        p.init_images = imgs
-         # Override object param
-         before_gpu_call()
+
+         shared.state.begin()
+
        with self.queue_lock:
            processed = process_images(p)
-         after_gpu_call()
+
+         shared.state.end()

        b64images = list(map(encode_pil_to_base64, processed.images))
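
For context, the removed before_gpu_call() / after_gpu_call() helpers were, as their deleted comment notes, a copy of the state handling in wrap_gradio_gpu_call from webui.py, split in two so the queue lock could be acquired inside the API handlers. After this change both handlers call shared.state.begin() and shared.state.end() instead, which suggests the same bookkeeping now lives on the shared State object. Below is a minimal sketch of what those methods could look like if they simply absorb the deleted helpers; the attribute resets are taken from the removed code, while placing the methods on a State class in modules/shared.py and the body of get_job_timestamp() are assumptions made for illustration.

import datetime
import time

from modules import devices


class State:
    # Sketch only: the real class has more fields and methods; just the
    # pieces needed by begin()/end() are shown here.

    def get_job_timestamp(self):
        # Assumed helper: the removed code only shows that it exists.
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    def begin(self):
        # Mirrors before_gpu_call(): free GPU memory, then reset the
        # per-job progress fields before work starts.
        devices.torch_gc()

        self.sampling_step = 0
        self.job_count = -1
        self.job_no = 0
        self.job_timestamp = self.get_job_timestamp()
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.skipped = False
        self.interrupted = False
        self.textinfo = None
        self.time_start = time.time()

    def end(self):
        # Mirrors after_gpu_call(): clear the job markers and free GPU memory.
        self.job = ""
        self.job_count = 0

        devices.torch_gc()

Moving this bookkeeping onto State gives the Gradio wrapper and the API handlers a single code path instead of the hand-copied duplicate that previously lived in the API module, which is why both text2imgapi and img2imgapi now simply wrap the with self.queue_lock: block in begin() and end() calls.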