2
2
import io
3
3
import time
4
4
import uvicorn
5
- from gradio .processing_utils import decode_base64_to_file , decode_base64_to_image
6
- from fastapi import APIRouter , Depends , HTTPException
5
+ from threading import Lock
6
+ from gradio .processing_utils import encode_pil_to_base64 , decode_base64_to_file , decode_base64_to_image
7
+ from fastapi import APIRouter , Depends , FastAPI , HTTPException
7
8
import modules .shared as shared
8
9
from modules .api .models import *
9
10
from modules .processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
10
- from modules .sd_samplers import all_samplers , sample_to_image , samples_to_image_grid
11
+ from modules .sd_samplers import all_samplers
11
12
from modules .extras import run_extras , run_pnginfo
12
-
13
+ from modules .sd_models import checkpoints_list
14
+ from modules .realesrgan_model import get_realesrgan_models
15
+ from typing import List
13
16
14
17
def upscaler_to_index (name : str ):
15
18
try :
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
37
40
38
41
39
42
class Api :
40
- def __init__ (self , app , queue_lock ):
43
+ def __init__ (self , app : FastAPI , queue_lock : Lock ):
41
44
self .router = APIRouter ()
42
45
self .app = app
43
46
self .queue_lock = queue_lock
@@ -48,6 +51,18 @@ def __init__(self, app, queue_lock):
48
51
self .app .add_api_route ("/sdapi/v1/png-info" , self .pnginfoapi , methods = ["POST" ], response_model = PNGInfoResponse )
49
52
self .app .add_api_route ("/sdapi/v1/progress" , self .progressapi , methods = ["GET" ], response_model = ProgressResponse )
50
53
self .app .add_api_route ("/sdapi/v1/interrupt" , self .interruptapi , methods = ["POST" ])
54
+ self .app .add_api_route ("/sdapi/v1/options" , self .get_config , methods = ["GET" ], response_model = OptionsModel )
55
+ self .app .add_api_route ("/sdapi/v1/options" , self .set_config , methods = ["POST" ])
56
+ self .app .add_api_route ("/sdapi/v1/cmd-flags" , self .get_cmd_flags , methods = ["GET" ], response_model = FlagsModel )
57
+ self .app .add_api_route ("/sdapi/v1/samplers" , self .get_samplers , methods = ["GET" ], response_model = List [SamplerItem ])
58
+ self .app .add_api_route ("/sdapi/v1/upscalers" , self .get_upscalers , methods = ["GET" ], response_model = List [UpscalerItem ])
59
+ self .app .add_api_route ("/sdapi/v1/sd-models" , self .get_sd_models , methods = ["GET" ], response_model = List [SDModelItem ])
60
+ self .app .add_api_route ("/sdapi/v1/hypernetworks" , self .get_hypernetworks , methods = ["GET" ], response_model = List [HypernetworkItem ])
61
+ self .app .add_api_route ("/sdapi/v1/face-restorers" , self .get_face_restorers , methods = ["GET" ], response_model = List [FaceRestorerItem ])
62
+ self .app .add_api_route ("/sdapi/v1/realesrgan-models" , self .get_realesrgan_models , methods = ["GET" ], response_model = List [RealesrganItem ])
63
+ self .app .add_api_route ("/sdapi/v1/prompt-styles" , self .get_promp_styles , methods = ["GET" ], response_model = List [PromptStyleItem ])
64
+ self .app .add_api_route ("/sdapi/v1/artist-categories" , self .get_artists_categories , methods = ["GET" ], response_model = List [str ])
65
+ self .app .add_api_route ("/sdapi/v1/artists" , self .get_artists , methods = ["GET" ], response_model = List [ArtistItem ])
51
66
52
67
def text2imgapi (self , txt2imgreq : StableDiffusionTxt2ImgProcessingAPI ):
53
68
sampler_index = sampler_to_index (txt2imgreq .sampler_index )
@@ -190,6 +205,66 @@ def interruptapi(self):
190
205
shared .state .interrupt ()
191
206
192
207
return {}
208
+
209
+ def get_config (self ):
210
+ options = {}
211
+ for key in shared .opts .data .keys ():
212
+ metadata = shared .opts .data_labels .get (key )
213
+ if (metadata is not None ):
214
+ options .update ({key : shared .opts .data .get (key , shared .opts .data_labels .get (key ).default )})
215
+ else :
216
+ options .update ({key : shared .opts .data .get (key , None )})
217
+
218
+ return options
219
+
220
+ def set_config (self , req : OptionsModel ):
221
+ reqDict = vars (req )
222
+ for o in reqDict :
223
+ setattr (shared .opts , o , reqDict [o ])
224
+
225
+ shared .opts .save (shared .config_filename )
226
+ return
227
+
228
+ def get_cmd_flags (self ):
229
+ return vars (shared .cmd_opts )
230
+
231
+ def get_samplers (self ):
232
+ return [{"name" :sampler [0 ], "aliases" :sampler [2 ], "options" :sampler [3 ]} for sampler in all_samplers ]
233
+
234
+ def get_upscalers (self ):
235
+ upscalers = []
236
+
237
+ for upscaler in shared .sd_upscalers :
238
+ u = upscaler .scaler
239
+ upscalers .append ({"name" :u .name , "model_name" :u .model_name , "model_path" :u .model_path , "model_url" :u .model_url })
240
+
241
+ return upscalers
242
+
243
+ def get_sd_models (self ):
244
+ return [{"title" :x .title , "model_name" :x .model_name , "hash" :x .hash , "filename" : x .filename , "config" : x .config } for x in checkpoints_list .values ()]
245
+
246
+ def get_hypernetworks (self ):
247
+ return [{"name" : name , "path" : shared .hypernetworks [name ]} for name in shared .hypernetworks ]
248
+
249
+ def get_face_restorers (self ):
250
+ return [{"name" :x .name (), "cmd_dir" : getattr (x , "cmd_dir" , None )} for x in shared .face_restorers ]
251
+
252
+ def get_realesrgan_models (self ):
253
+ return [{"name" :x .name ,"path" :x .data_path , "scale" :x .scale } for x in get_realesrgan_models (None )]
254
+
255
+ def get_promp_styles (self ):
256
+ styleList = []
257
+ for k in shared .prompt_styles .styles :
258
+ style = shared .prompt_styles .styles [k ]
259
+ styleList .append ({"name" :style [0 ], "prompt" : style [1 ], "negative_prompr" : style [2 ]})
260
+
261
+ return styleList
262
+
263
+ def get_artists_categories (self ):
264
+ return shared .artist_db .cats
265
+
266
+ def get_artists (self ):
267
+ return [{"name" :x [0 ], "score" :x [1 ], "category" :x [2 ]} for x in shared .artist_db .artists ]
193
268
194
269
def launch (self , server_name , port ):
195
270
self .app .include_router (self .router )
0 commit comments