|
5 | 5 | from threading import Lock
|
6 | 6 | from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
7 | 7 | from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
| 8 | +from fastapi.security import HTTPBasic, HTTPBasicCredentials |
| 9 | +from secrets import compare_digest |
| 10 | + |
8 | 11 | import modules.shared as shared
|
9 | 12 | from modules import sd_samplers
|
10 | 13 | from modules.api.models import *
|
@@ -61,30 +64,48 @@ def encode_pil_to_base64(image):
|
61 | 64 |
|
62 | 65 | class Api:
|
63 | 66 | def __init__(self, app: FastAPI, queue_lock: Lock):
|
| 67 | + if shared.cmd_opts.api_auth: |
| 68 | + self.credenticals = dict() |
| 69 | + for auth in shared.cmd_opts.api_auth.split(","): |
| 70 | + user, password = auth.split(":") |
| 71 | + self.credenticals[user] = password |
| 72 | + |
64 | 73 | self.router = APIRouter()
|
65 | 74 | self.app = app
|
66 | 75 | self.queue_lock = queue_lock
|
67 |
| - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) |
68 |
| - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) |
69 |
| - self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) |
70 |
| - self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) |
71 |
| - self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) |
72 |
| - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) |
73 |
| - self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) |
74 |
| - self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) |
75 |
| - self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) |
76 |
| - self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) |
77 |
| - self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) |
78 |
| - self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) |
79 |
| - self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) |
80 |
| - self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) |
81 |
| - self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) |
82 |
| - self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) |
83 |
| - self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) |
84 |
| - self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) |
85 |
| - self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) |
86 |
| - self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) |
87 |
| - self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) |
| 76 | + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) |
| 77 | + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) |
| 78 | + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) |
| 79 | + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) |
| 80 | + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) |
| 81 | + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) |
| 82 | + self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) |
| 83 | + self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) |
| 84 | + self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) |
| 85 | + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) |
| 86 | + self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) |
| 87 | + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) |
| 88 | + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) |
| 89 | + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) |
| 90 | + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) |
| 91 | + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) |
| 92 | + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) |
| 93 | + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) |
| 94 | + self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) |
| 95 | + self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) |
| 96 | + self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) |
| 97 | + |
| 98 | + def add_api_route(self, path: str, endpoint, **kwargs): |
| 99 | + if shared.cmd_opts.api_auth: |
| 100 | + return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) |
| 101 | + return self.app.add_api_route(path, endpoint, **kwargs) |
| 102 | + |
| 103 | + def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): |
| 104 | + if credenticals.username in self.credenticals: |
| 105 | + if compare_digest(credenticals.password, self.credenticals[credenticals.username]): |
| 106 | + return True |
| 107 | + |
| 108 | + raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) |
88 | 109 |
|
89 | 110 | def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
90 | 111 | populate = txt2imgreq.copy(update={ # Override __init__ params
|
|
0 commit comments