diff --git a/python/mlc_llm/cli/serve.py b/python/mlc_llm/cli/serve.py index bf1f5e5a36..4495154fbd 100644 --- a/python/mlc_llm/cli/serve.py +++ b/python/mlc_llm/cli/serve.py @@ -194,6 +194,12 @@ def main(argv): default=["*"], help="allowed headers" + ' (default: "%(default)s")', ) + parser.add_argument( + "--api-key", + type=str, + default=None, + help="API key for authentication. If not provided, authentication is disabled.", + ) parsed = parser.parse_args(argv) additional_models = [] @@ -236,4 +242,5 @@ def main(argv): allow_origins=parsed.allow_origins, allow_methods=parsed.allow_methods, allow_headers=parsed.allow_headers, + api_key=parsed.api_key, ) diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index c00ed1adc5..ae571aebfd 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -51,6 +51,7 @@ def serve( allow_origins: Any, allow_methods: Any, allow_headers: Any, + api_key: Optional[str] = None, ): # pylint: disable=too-many-arguments, too-many-locals """Serve the model with the specified configuration.""" # Create engine and start the background loop @@ -84,6 +85,7 @@ def serve( with ServerContext() as server_context: server_context.add_model(model, async_engine) + server_context.api_key = api_key app = fastapi.FastAPI() app.add_middleware( diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 18a415e413..c4054f2d7b 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -18,7 +18,20 @@ from mlc_llm.serve import engine_base, engine_utils from mlc_llm.serve.server import ServerContext -app = fastapi.APIRouter() + +def verify_api_key(request: fastapi.Request): + """Function to verify API key""" + server_context = ServerContext.current() + # Only perform verification when API key is configured + if server_context is not None and server_context.api_key is not None: + provided_key = request.headers.get("Authorization", "").replace("Bearer ", "") + if provided_key != server_context.api_key: + raise fastapi.HTTPException(status_code=401, detail="Invalid API Key") + # Skip verification if no API key is configured + + +app = fastapi.APIRouter(dependencies=[fastapi.Depends(verify_api_key)]) + ################ v1/models ################ diff --git a/python/mlc_llm/serve/server/server_context.py b/python/mlc_llm/serve/server/server_context.py index 2f4bf26626..ea38d967ab 100644 --- a/python/mlc_llm/serve/server/server_context.py +++ b/python/mlc_llm/serve/server/server_context.py @@ -15,6 +15,7 @@ class ServerContext: def __init__(self) -> None: self._models: Dict[str, AsyncMLCEngine] = {} + self.api_key: Optional[str] = None # New API key property def __enter__(self): if ServerContext.server_context is not None: