|
5 | 5 | from threading import Lock
|
6 | 6 | from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
7 | 7 | from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
| 8 | +from fastapi.security import HTTPBasic, HTTPBasicCredentials |
| 9 | +from secrets import compare_digest |
| 10 | + |
8 | 11 | import modules.shared as shared
|
9 | 12 | from modules.api.models import *
|
10 | 13 | from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
@@ -57,29 +60,47 @@ def encode_pil_to_base64(image):
|
57 | 60 |
|
58 | 61 | class Api:
|
59 | 62 | def __init__(self, app: FastAPI, queue_lock: Lock):
|
| 63 | + if shared.cmd_opts.api_auth: |
| 64 | + self.credenticals = dict() |
| 65 | + for auth in shared.cmd_opts.api_auth.split(","): |
| 66 | + user, password = auth.split(":") |
| 67 | + self.credenticals[user] = password |
| 68 | + |
60 | 69 | self.router = APIRouter()
|
61 | 70 | self.app = app
|
62 | 71 | self.queue_lock = queue_lock
|
63 |
| - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) |
64 |
| - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) |
65 |
| - self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) |
66 |
| - self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) |
67 |
| - self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) |
68 |
| - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) |
69 |
| - self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) |
70 |
| - self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) |
71 |
| - self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) |
72 |
| - self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) |
73 |
| - self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) |
74 |
| - self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) |
75 |
| - self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) |
76 |
| - self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) |
77 |
| - self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) |
78 |
| - self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) |
79 |
| - self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) |
80 |
| - self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) |
81 |
| - self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) |
82 |
| - self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) |
| 72 | + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) |
| 73 | + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) |
| 74 | + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) |
| 75 | + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) |
| 76 | + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) |
| 77 | + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) |
| 78 | + self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) |
| 79 | + self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) |
| 80 | + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) |
| 81 | + self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) |
| 82 | + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) |
| 83 | + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) |
| 84 | + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) |
| 85 | + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) |
| 86 | + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) |
| 87 | + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) |
| 88 | + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) |
| 89 | + self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) |
| 90 | + self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) |
| 91 | + self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) |
| 92 | + |
| 93 | + def add_api_route(self, path: str, endpoint, **kwargs): |
| 94 | + if shared.cmd_opts.api_auth: |
| 95 | + return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) |
| 96 | + return self.app.add_api_route(path, endpoint, **kwargs) |
| 97 | + |
| 98 | + def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): |
| 99 | + if credenticals.username in self.credenticals: |
| 100 | + if compare_digest(credenticals.password, self.credenticals[credenticals.username]): |
| 101 | + return True |
| 102 | + |
| 103 | + raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) |
83 | 104 |
|
84 | 105 | def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
85 | 106 | sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
|
0 commit comments