Skip to content

Commit 41e242b

Browse files
Merge pull request #4733 from MaikoTan/api-authorization
feat: add http basic authentication for api
2 parents 5a6387e + 336c341 commit 41e242b

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

modules/api/api.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from threading import Lock
66
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
77
from fastapi import APIRouter, Depends, FastAPI, HTTPException
8+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
9+
from secrets import compare_digest
10+
811
import modules.shared as shared
912
from modules import sd_samplers
1013
from modules.api.models import *
@@ -61,30 +64,48 @@ def encode_pil_to_base64(image):
6164

6265
class Api:
6366
def __init__(self, app: FastAPI, queue_lock: Lock):
67+
if shared.cmd_opts.api_auth:
68+
self.credenticals = dict()
69+
for auth in shared.cmd_opts.api_auth.split(","):
70+
user, password = auth.split(":")
71+
self.credenticals[user] = password
72+
6473
self.router = APIRouter()
6574
self.app = app
6675
self.queue_lock = queue_lock
67-
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
68-
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
69-
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
70-
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
71-
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
72-
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
73-
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
74-
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
75-
self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
76-
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
77-
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
78-
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
79-
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
80-
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
81-
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
82-
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
83-
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
84-
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
85-
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
86-
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
87-
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
76+
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
77+
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
78+
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
79+
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
80+
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
81+
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
82+
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
83+
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
84+
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
85+
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
86+
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
87+
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
88+
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
89+
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
90+
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
91+
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
92+
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
93+
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
94+
self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
95+
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
96+
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
97+
98+
def add_api_route(self, path: str, endpoint, **kwargs):
99+
if shared.cmd_opts.api_auth:
100+
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
101+
return self.app.add_api_route(path, endpoint, **kwargs)
102+
103+
def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
104+
if credenticals.username in self.credenticals:
105+
if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
106+
return True
107+
108+
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
88109

89110
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
90111
populate = txt2imgreq.copy(update={ # Override __init__ params

modules/shared.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
8282
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
8383
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
84+
parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
8485
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
8586
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
8687
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)

0 commit comments

Comments
 (0)