5
5
from threading import Lock
6
6
from gradio .processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
7
7
from fastapi import APIRouter , Depends , FastAPI , HTTPException
8
+ from fastapi .security import HTTPBasic , HTTPBasicCredentials
9
+ from secrets import compare_digest
10
+
8
11
import modules .shared as shared
12
+ from modules import sd_samplers , deepbooru
9
13
from modules .api .models import *
10
14
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
11
- from modules .sd_samplers import all_samplers
12
15
from modules .extras import run_extras , run_pnginfo
13
16
from PIL import PngImagePlugin
14
17
from modules .sd_models import checkpoints_list
15
18
from modules .realesrgan_model import get_realesrgan_models
16
19
from typing import List
17
20
18
- if shared .cmd_opts .deepdanbooru :
19
- from modules .deepbooru import get_deepbooru_tags
20
-
21
21
def upscaler_to_index (name : str ):
22
22
try :
23
23
return [x .name .lower () for x in shared .sd_upscalers ].index (name .lower ())
24
24
except :
25
25
raise HTTPException (status_code = 400 , detail = f"Invalid upscaler, needs to be on of these: { ' , ' .join ([x .name for x in sd_upscalers ])} " )
26
26
27
27
28
- sampler_to_index = lambda name : next (filter (lambda row : name .lower () == row [1 ].name .lower (), enumerate (all_samplers )), None )
28
+ def validate_sampler_name (name ):
29
+ config = sd_samplers .all_samplers_map .get (name , None )
30
+ if config is None :
31
+ raise HTTPException (status_code = 404 , detail = "Sampler not found" )
29
32
33
+ return name
30
34
31
35
def setUpscalers (req : dict ):
32
36
reqDict = vars (req )
@@ -57,39 +61,53 @@ def encode_pil_to_base64(image):
57
61
58
62
class Api :
59
63
def __init__ (self , app : FastAPI , queue_lock : Lock ):
64
+ if shared .cmd_opts .api_auth :
65
+ self .credenticals = dict ()
66
+ for auth in shared .cmd_opts .api_auth .split ("," ):
67
+ user , password = auth .split (":" )
68
+ self .credenticals [user ] = password
69
+
60
70
self .router = APIRouter ()
61
71
self .app = app
62
72
self .queue_lock = queue_lock
63
- self .app .add_api_route ("/sdapi/v1/txt2img" , self .text2imgapi , methods = ["POST" ], response_model = TextToImageResponse )
64
- self .app .add_api_route ("/sdapi/v1/img2img" , self .img2imgapi , methods = ["POST" ], response_model = ImageToImageResponse )
65
- self .app .add_api_route ("/sdapi/v1/extra-single-image" , self .extras_single_image_api , methods = ["POST" ], response_model = ExtrasSingleImageResponse )
66
- self .app .add_api_route ("/sdapi/v1/extra-batch-images" , self .extras_batch_images_api , methods = ["POST" ], response_model = ExtrasBatchImagesResponse )
67
- self .app .add_api_route ("/sdapi/v1/png-info" , self .pnginfoapi , methods = ["POST" ], response_model = PNGInfoResponse )
68
- self .app .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = ProgressResponse )
69
- self .app .add_api_route ("/sdapi/v1/interrogate" , self .interrogateapi , methods = ["POST" ])
70
- self .app .add_api_route ("/sdapi/v1/interrupt" , self .interruptapi , methods = ["POST" ])
71
- self .app .add_api_route ("/sdapi/v1/options" , self .get_config , methods = ["GET" ], response_model = OptionsModel )
72
- self .app .add_api_route ("/sdapi/v1/options" , self .set_config , methods = ["POST" ])
73
- self .app .add_api_route ("/sdapi/v1/cmd-flags" , self .get_cmd_flags , methods = ["GET" ], response_model = FlagsModel )
74
- self .app .add_api_route ("/sdapi/v1/samplers" , self .get_samplers , methods = ["GET" ], response_model = List [SamplerItem ])
75
- self .app .add_api_route ("/sdapi/v1/upscalers" , self .get_upscalers , methods = ["GET" ], response_model = List [UpscalerItem ])
76
- self .app .add_api_route ("/sdapi/v1/sd-models" , self .get_sd_models , methods = ["GET" ], response_model = List [SDModelItem ])
77
- self .app .add_api_route ("/sdapi/v1/hypernetworks" , self .get_hypernetworks , methods = ["GET" ], response_model = List [HypernetworkItem ])
78
- self .app .add_api_route ("/sdapi/v1/face-restorers" , self .get_face_restorers , methods = ["GET" ], response_model = List [FaceRestorerItem ])
79
- self .app .add_api_route ("/sdapi/v1/realesrgan-models" , self .get_realesrgan_models , methods = ["GET" ], response_model = List [RealesrganItem ])
80
- self .app .add_api_route ("/sdapi/v1/prompt-styles" , self .get_promp_styles , methods = ["GET" ], response_model = List [PromptStyleItem ])
81
- self .app .add_api_route ("/sdapi/v1/artist-categories" , self .get_artists_categories , methods = ["GET" ], response_model = List [str ])
82
- self .app .add_api_route ("/sdapi/v1/artists" , self .get_artists , methods = ["GET" ], response_model = List [ArtistItem ])
73
+ self .add_api_route ("/sdapi/v1/txt2img" , self .text2imgapi , methods = ["POST" ], response_model = TextToImageResponse )
74
+ self .add_api_route ("/sdapi/v1/img2img" , self .img2imgapi , methods = ["POST" ], response_model = ImageToImageResponse )
75
+ self .add_api_route ("/sdapi/v1/extra-single-image" , self .extras_single_image_api , methods = ["POST" ], response_model = ExtrasSingleImageResponse )
76
+ self .add_api_route ("/sdapi/v1/extra-batch-images" , self .extras_batch_images_api , methods = ["POST" ], response_model = ExtrasBatchImagesResponse )
77
+ self .add_api_route ("/sdapi/v1/png-info" , self .pnginfoapi , methods = ["POST" ], response_model = PNGInfoResponse )
78
+ self .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = ProgressResponse )
79
+ self .add_api_route ("/sdapi/v1/interrogate" , self .interrogateapi , methods = ["POST" ])
80
+ self .add_api_route ("/sdapi/v1/interrupt" , self .interruptapi , methods = ["POST" ])
81
+ self .add_api_route ("/sdapi/v1/skip" , self .skip , methods = ["POST" ])
82
+ self .add_api_route ("/sdapi/v1/options" , self .get_config , methods = ["GET" ], response_model = OptionsModel )
83
+ self .add_api_route ("/sdapi/v1/options" , self .set_config , methods = ["POST" ])
84
+ self .add_api_route ("/sdapi/v1/cmd-flags" , self .get_cmd_flags , methods = ["GET" ], response_model = FlagsModel )
85
+ self .add_api_route ("/sdapi/v1/samplers" , self .get_samplers , methods = ["GET" ], response_model = List [SamplerItem ])
86
+ self .add_api_route ("/sdapi/v1/upscalers" , self .get_upscalers , methods = ["GET" ], response_model = List [UpscalerItem ])
87
+ self .add_api_route ("/sdapi/v1/sd-models" , self .get_sd_models , methods = ["GET" ], response_model = List [SDModelItem ])
88
+ self .add_api_route ("/sdapi/v1/hypernetworks" , self .get_hypernetworks , methods = ["GET" ], response_model = List [HypernetworkItem ])
89
+ self .add_api_route ("/sdapi/v1/face-restorers" , self .get_face_restorers , methods = ["GET" ], response_model = List [FaceRestorerItem ])
90
+ self .add_api_route ("/sdapi/v1/realesrgan-models" , self .get_realesrgan_models , methods = ["GET" ], response_model = List [RealesrganItem ])
91
+ self .add_api_route ("/sdapi/v1/prompt-styles" , self .get_promp_styles , methods = ["GET" ], response_model = List [PromptStyleItem ])
92
+ self .add_api_route ("/sdapi/v1/artist-categories" , self .get_artists_categories , methods = ["GET" ], response_model = List [str ])
93
+ self .add_api_route ("/sdapi/v1/artists" , self .get_artists , methods = ["GET" ], response_model = List [ArtistItem ])
94
+
95
+ def add_api_route (self , path : str , endpoint , ** kwargs ):
96
+ if shared .cmd_opts .api_auth :
97
+ return self .app .add_api_route (path , endpoint , dependencies = [Depends (self .auth )], ** kwargs )
98
+ return self .app .add_api_route (path , endpoint , ** kwargs )
99
+
100
+ def auth (self , credenticals : HTTPBasicCredentials = Depends (HTTPBasic ())):
101
+ if credenticals .username in self .credenticals :
102
+ if compare_digest (credenticals .password , self .credenticals [credenticals .username ]):
103
+ return True
104
+
105
+ raise HTTPException (status_code = 401 , detail = "Incorrect username or password" , headers = {"WWW-Authenticate" : "Basic" })
83
106
84
107
def text2imgapi (self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ):
85
- sampler_index = sampler_to_index (txt2imgreq .sampler_index )
86
-
87
- if sampler_index is None :
88
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
89
-
90
108
populate = txt2imgreq .copy (update = { # Override __init__ params
91
109
"sd_model" : shared .sd_model ,
92
- "sampler_index " : sampler_index [ 0 ] ,
110
+ "sampler_name " : validate_sampler_name ( txt2imgreq . sampler_index ) ,
93
111
"do_not_save_samples" : True ,
94
112
"do_not_save_grid" : True
95
113
}
@@ -109,12 +127,6 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
109
127
return TextToImageResponse (images = b64images , parameters = vars (txt2imgreq ), info = processed .js ())
110
128
111
129
def img2imgapi (self , img2imgreq : StableDiffusionImg2ImgProcessingAPI ):
112
- sampler_index = sampler_to_index (img2imgreq .sampler_index )
113
-
114
- if sampler_index is None :
115
- raise HTTPException (status_code = 404 , detail = "Sampler not found" )
116
-
117
-
118
130
init_images = img2imgreq .init_images
119
131
if init_images is None :
120
132
raise HTTPException (status_code = 404 , detail = "Init image not found" )
@@ -123,10 +135,9 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
123
135
if mask :
124
136
mask = decode_base64_to_image (mask )
125
137
126
-
127
138
populate = img2imgreq .copy (update = { # Override __init__ params
128
139
"sd_model" : shared .sd_model ,
129
- "sampler_index " : sampler_index [ 0 ] ,
140
+ "sampler_name " : validate_sampler_name ( img2imgreq . sampler_index ) ,
130
141
"do_not_save_samples" : True ,
131
142
"do_not_save_grid" : True ,
132
143
"mask" : mask
@@ -231,10 +242,7 @@ def interrogateapi(self, interrogatereq: InterrogateRequest):
231
242
if interrogatereq .model == "clip" :
232
243
processed = shared .interrogator .interrogate (img )
233
244
elif interrogatereq .model == "deepdanbooru" :
234
- if shared .cmd_opts .deepdanbooru :
235
- processed = get_deepbooru_tags (img )
236
- else :
237
- raise HTTPException (status_code = 404 , detail = "Model not found. Add --deepdanbooru when launching for using the model." )
245
+ processed = deepbooru .model .tag (img )
238
246
else :
239
247
raise HTTPException (status_code = 404 , detail = "Model not found" )
240
248
@@ -245,6 +253,9 @@ def interruptapi(self):
245
253
246
254
return {}
247
255
256
+ def skip (self ):
257
+ shared .state .skip ()
258
+
248
259
def get_config (self ):
249
260
options = {}
250
261
for key in shared .opts .data .keys ():
@@ -256,14 +267,9 @@ def get_config(self):
256
267
257
268
return options
258
269
259
- def set_config (self , req : OptionsModel ):
260
- # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
261
- # overwrite all options with default values.
262
- raise RuntimeError ('Setting options via API is not supported' )
263
-
264
- reqDict = vars (req )
265
- for o in reqDict :
266
- setattr (shared .opts , o , reqDict [o ])
270
+ def set_config (self , req : Dict [str , Any ]):
271
+ for k , v in req .items ():
272
+ shared .opts .set (k , v )
267
273
268
274
shared .opts .save (shared .config_filename )
269
275
return
@@ -272,7 +278,7 @@ def get_cmd_flags(self):
272
278
return vars (shared .cmd_opts )
273
279
274
280
def get_samplers (self ):
275
- return [{"name" :sampler [0 ], "aliases" :sampler [2 ], "options" :sampler [3 ]} for sampler in all_samplers ]
281
+ return [{"name" : sampler [0 ], "aliases" :sampler [2 ], "options" :sampler [3 ]} for sampler in sd_samplers . all_samplers ]
276
282
277
283
def get_upscalers (self ):
278
284
upscalers = []
0 commit comments