9
9
from secrets import compare_digest
10
10
11
11
import modules .shared as shared
12
+ from modules import sd_samplers
12
13
from modules .api .models import *
13
14
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
14
- from modules .sd_samplers import all_samplers
15
15
from modules .extras import run_extras , run_pnginfo
16
16
from PIL import PngImagePlugin
17
17
from modules .sd_models import checkpoints_list
@@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
28
28
raise HTTPException (status_code = 400 , detail = f"Invalid upscaler, needs to be on of these: { ' , ' .join ([x .name for x in sd_upscalers ])} " )
29
29
30
30
31
- sampler_to_index = lambda name : next (filter (lambda row : name .lower () == row [1 ].name .lower (), enumerate (all_samplers )), None )
31
+ def validate_sampler_name (name ):
32
+ config = sd_samplers .all_samplers_map .get (name , None )
33
+ if config is None :
34
+ raise HTTPException (status_code = 404 , detail = "Sampler not found" )
32
35
36
+ return name
33
37
34
38
def setUpscalers (req : dict ):
35
39
reqDict = vars (req )
@@ -77,6 +81,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
77
81
self .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = ProgressResponse )
78
82
self .add_api_route ("/sdapi/v1/interrogate" , self .interrogateapi , methods = ["POST" ])
79
83
self .add_api_route ("/sdapi/v1/interrupt" , self .interruptapi , methods = ["POST" ])
84
+ self .add_api_route ("/sdapi/v1/skip" , self .skip , methods = ["POST" ])
80
85
self .add_api_route ("/sdapi/v1/options" , self .get_config , methods = ["GET" ], response_model = OptionsModel )
81
86
self .add_api_route ("/sdapi/v1/options" , self .set_config , methods = ["POST" ])
82
87
self .add_api_route ("/sdapi/v1/cmd-flags" , self .get_cmd_flags , methods = ["GET" ], response_model = FlagsModel )
@@ -103,14 +108,9 @@ def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
103
108
raise HTTPException (status_code = 401 , detail = "Incorrect username or password" , headers = {"WWW-Authenticate" : "Basic" })
104
109
105
110
def text2imgapi (self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ):
106
- sampler_index = sampler_to_index (txt2imgreq .sampler_index )
107
-
108
- if sampler_index is None :
109
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
110
-
111
111
populate = txt2imgreq .copy (update = { # Override __init__ params
112
112
"sd_model" : shared .sd_model ,
113
- "sampler_index " : sampler_index [ 0 ] ,
113
+ "sampler_name " : validate_sampler_name ( txt2imgreq . sampler_index ) ,
114
114
"do_not_save_samples" : True ,
115
115
"do_not_save_grid" : True
116
116
}
@@ -130,12 +130,6 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
130
130
return TextToImageResponse (images = b64images , parameters = vars (txt2imgreq ), info = processed .js ())
131
131
132
132
def img2imgapi (self , img2imgreq : StableDiffusionImg2ImgProcessingAPI ):
133
- sampler_index = sampler_to_index (img2imgreq .sampler_index )
134
-
135
- if sampler_index is None :
136
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
137
-
138
-
139
133
init_images = img2imgreq .init_images
140
134
if init_images is None :
141
135
raise HTTPException (status_code = 404 , detail = "Init image not found" )
@@ -144,10 +138,9 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
144
138
if mask :
145
139
mask = decode_base64_to_image (mask )
146
140
147
-
148
141
populate = img2imgreq .copy (update = { # Override __init__ params
149
142
"sd_model" : shared .sd_model ,
150
- "sampler_index " : sampler_index [ 0 ] ,
143
+ "sampler_name " : validate_sampler_name ( img2imgreq . sampler_index ) ,
151
144
"do_not_save_samples" : True ,
152
145
"do_not_save_grid" : True ,
153
146
"mask" : mask
@@ -266,6 +259,9 @@ def interruptapi(self):
266
259
267
260
return {}
268
261
262
+ def skip (self ):
263
+ shared .state .skip ()
264
+
269
265
def get_config (self ):
270
266
options = {}
271
267
for key in shared .opts .data .keys ():
@@ -277,14 +273,10 @@ def get_config(self):
277
273
278
274
return options
279
275
280
- def set_config (self , req : OptionsModel ):
281
- # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
282
- # overwrite all options with default values.
283
- raise RuntimeError ('Setting options via API is not supported' )
284
-
285
- reqDict = vars (req )
286
- for o in reqDict :
287
- setattr (shared .opts , o , reqDict [o ])
276
+ def set_config (self , req : Dict [str , Any ]):
277
+
278
+ for o in req :
279
+ setattr (shared .opts , o , req [o ])
288
280
289
281
shared .opts .save (shared .config_filename )
290
282
return
@@ -293,7 +285,7 @@ def get_cmd_flags(self):
293
285
return vars (shared .cmd_opts )
294
286
295
287
def get_samplers (self ):
296
- return [{"name" :sampler [0 ], "aliases" :sampler [2 ], "options" :sampler [3 ]} for sampler in all_samplers ]
288
+ return [{"name" :sampler [0 ], "aliases" :sampler [2 ], "options" :sampler [3 ]} for sampler in sd_samplers . all_samplers ]
297
289
298
290
def get_upscalers (self ):
299
291
upscalers = []
0 commit comments