Skip to content

Commit 371c4b9

Browse files
Merge pull request #4218 from bamarillo/utils-endpoints
[API][Feature] Utils endpoints
2 parents f674c48 + 17bd3f4 commit 371c4b9

File tree

3 files changed

+210
-8
lines changed

3 files changed

+210
-8
lines changed

modules/api/api.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
import io
33
import time
44
import uvicorn
5-
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
6-
from fastapi import APIRouter, Depends, HTTPException
5+
from threading import Lock
6+
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
7+
from fastapi import APIRouter, Depends, FastAPI, HTTPException
78
import modules.shared as shared
89
from modules.api.models import *
910
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
10-
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
11+
from modules.sd_samplers import all_samplers
1112
from modules.extras import run_extras, run_pnginfo
12-
13+
from modules.sd_models import checkpoints_list
14+
from modules.realesrgan_model import get_realesrgan_models
15+
from typing import List
1316

1417
def upscaler_to_index(name: str):
1518
try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
3740

3841

3942
class Api:
40-
def __init__(self, app, queue_lock):
43+
def __init__(self, app: FastAPI, queue_lock: Lock):
4144
self.router = APIRouter()
4245
self.app = app
4346
self.queue_lock = queue_lock
@@ -48,6 +51,18 @@ def __init__(self, app, queue_lock):
4851
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
4952
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
5053
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
54+
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
55+
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
56+
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
57+
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
58+
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
59+
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
60+
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
61+
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
62+
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
63+
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
64+
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
65+
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
5166

5267
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
5368
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -190,6 +205,66 @@ def interruptapi(self):
190205
shared.state.interrupt()
191206

192207
return {}
208+
209+
def get_config(self):
210+
options = {}
211+
for key in shared.opts.data.keys():
212+
metadata = shared.opts.data_labels.get(key)
213+
if(metadata is not None):
214+
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
215+
else:
216+
options.update({key: shared.opts.data.get(key, None)})
217+
218+
return options
219+
220+
def set_config(self, req: OptionsModel):
221+
reqDict = vars(req)
222+
for o in reqDict:
223+
setattr(shared.opts, o, reqDict[o])
224+
225+
shared.opts.save(shared.config_filename)
226+
return
227+
228+
def get_cmd_flags(self):
229+
return vars(shared.cmd_opts)
230+
231+
def get_samplers(self):
232+
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
233+
234+
def get_upscalers(self):
235+
upscalers = []
236+
237+
for upscaler in shared.sd_upscalers:
238+
u = upscaler.scaler
239+
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
240+
241+
return upscalers
242+
243+
def get_sd_models(self):
244+
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
245+
246+
def get_hypernetworks(self):
247+
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
248+
249+
def get_face_restorers(self):
250+
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
251+
252+
def get_realesrgan_models(self):
253+
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
254+
255+
def get_promp_styles(self):
256+
styleList = []
257+
for k in shared.prompt_styles.styles:
258+
style = shared.prompt_styles.styles[k]
259+
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
260+
261+
return styleList
262+
263+
def get_artists_categories(self):
264+
return shared.artist_db.cats
265+
266+
def get_artists(self):
267+
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
193268

194269
def launch(self, server_name, port):
195270
self.app.include_router(self.router)

modules/api/models.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import inspect
2-
from click import prompt
32
from pydantic import BaseModel, Field, create_model
4-
from typing import Any, Optional
3+
from typing import Any, Optional, Union
54
from typing_extensions import Literal
65
from inflection import underscore
76
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
8-
from modules.shared import sd_upscalers
7+
from modules.shared import sd_upscalers, opts, parser
98

109
API_NOT_ALLOWED = [
1110
"self",
@@ -166,3 +165,68 @@ class ProgressResponse(BaseModel):
166165
eta_relative: float = Field(title="ETA in secs")
167166
state: dict = Field(title="State", description="The current state snapshot")
168167
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
168+
169+
fields = {}
170+
for key, value in opts.data.items():
171+
metadata = opts.data_labels.get(key)
172+
optType = opts.typemap.get(type(value), type(value))
173+
174+
if (metadata is not None):
175+
fields.update({key: (Optional[optType], Field(
176+
default=metadata.default ,description=metadata.label))})
177+
else:
178+
fields.update({key: (Optional[optType], Field())})
179+
180+
OptionsModel = create_model("Options", **fields)
181+
182+
flags = {}
183+
_options = vars(parser)['_option_string_actions']
184+
for key in _options:
185+
if(_options[key].dest != 'help'):
186+
flag = _options[key]
187+
_type = str
188+
if(_options[key].default != None): _type = type(_options[key].default)
189+
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
190+
191+
FlagsModel = create_model("Flags", **flags)
192+
193+
class SamplerItem(BaseModel):
194+
name: str = Field(title="Name")
195+
aliases: list[str] = Field(title="Aliases")
196+
options: dict[str, str] = Field(title="Options")
197+
198+
class UpscalerItem(BaseModel):
199+
name: str = Field(title="Name")
200+
model_name: str | None = Field(title="Model Name")
201+
model_path: str | None = Field(title="Path")
202+
model_url: str | None = Field(title="URL")
203+
204+
class SDModelItem(BaseModel):
205+
title: str = Field(title="Title")
206+
model_name: str = Field(title="Model Name")
207+
hash: str = Field(title="Hash")
208+
filename: str = Field(title="Filename")
209+
config: str = Field(title="Config file")
210+
211+
class HypernetworkItem(BaseModel):
212+
name: str = Field(title="Name")
213+
path: str | None = Field(title="Path")
214+
215+
class FaceRestorerItem(BaseModel):
216+
name: str = Field(title="Name")
217+
cmd_dir: str | None = Field(title="Path")
218+
219+
class RealesrganItem(BaseModel):
220+
name: str = Field(title="Name")
221+
path: str | None = Field(title="Path")
222+
scale: int | None = Field(title="Scale")
223+
224+
class PromptStyleItem(BaseModel):
225+
name: str = Field(title="Name")
226+
prompt: str | None = Field(title="Prompt")
227+
negative_prompt: str | None = Field(title="Negative Prompt")
228+
229+
class ArtistItem(BaseModel):
230+
name: str = Field(title="Name")
231+
score: float = Field(title="Score")
232+
category: str = Field(title="Category")

test/utils_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
import requests
3+
4+
class UtilsTests(unittest.TestCase):
5+
def setUp(self):
6+
self.url_options = "http://localhost:7860/sdapi/v1/options"
7+
self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
8+
self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
9+
self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
10+
self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
11+
self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
12+
self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
13+
self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
14+
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
15+
self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
16+
self.url_artists = "http://localhost:7860/sdapi/v1/artists"
17+
18+
def test_options_get(self):
19+
self.assertEqual(requests.get(self.url_options).status_code, 200)
20+
21+
def test_options_write(self):
22+
response = requests.get(self.url_options)
23+
self.assertEqual(response.status_code, 200)
24+
25+
pre_value = response.json()["send_seed"]
26+
27+
self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
28+
29+
response = requests.get(self.url_options)
30+
self.assertEqual(response.status_code, 200)
31+
self.assertEqual(response.json()["send_seed"], not pre_value)
32+
33+
requests.post(self.url_options, json={"send_seed": pre_value})
34+
35+
def test_cmd_flags(self):
36+
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
37+
38+
def test_samplers(self):
39+
self.assertEqual(requests.get(self.url_samplers).status_code, 200)
40+
41+
def test_upscalers(self):
42+
self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
43+
44+
def test_sd_models(self):
45+
self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
46+
47+
def test_hypernetworks(self):
48+
self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
49+
50+
def test_face_restorers(self):
51+
self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
52+
53+
def test_realesrgan_models(self):
54+
self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
55+
56+
def test_prompt_styles(self):
57+
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
58+
59+
def test_artist_categories(self):
60+
self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
61+
62+
def test_artists(self):
63+
self.assertEqual(requests.get(self.url_artists).status_code, 200)

0 commit comments

Comments
 (0)