Skip to content

Commit 7435225

Browse files
authored
forward bearer token from authorization header to OpenAI (#19)
* forward bearer token from authorization header to OpenAI * format
1 parent 12a208c commit 7435225

File tree

4 files changed

+74
-2
lines changed

4 files changed

+74
-2
lines changed

services/chat-backend/src/app.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
# setup logging, before importing anything else (and before creating other loggers)
44
setup_logging()
55

6-
import json
76
import logging
87

9-
from fastapi import FastAPI, Request, Response # noqa: E402
8+
from fastapi import FastAPI, Request # noqa: E402
109
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
1110
from fastapi.responses import JSONResponse # noqa: E402
1211
from fastapi.routing import APIRoute # noqa: E402
12+
from fastapi.openapi.utils import get_openapi
1313

1414
from src.core.config import get_settings # noqa: E402
1515
from src.middleware import AuditMiddleware # noqa: E402
16+
from src.middleware.request_context import RequestContextMiddleware # noqa: E402
1617
from src.routers.main import api_router # noqa: E402
1718
from src.utils.openapi_converter import convert_openapi_spec # noqa: E402
1819

@@ -35,6 +36,38 @@ def custom_generate_unique_id(route: APIRoute) -> str:
3536
generate_unique_id_function=custom_generate_unique_id,
3637
)
3738

39+
def custom_openapi():
40+
if app.openapi_schema:
41+
return app.openapi_schema
42+
43+
openapi_schema = get_openapi(
44+
title=app.title,
45+
version="1.0.0",
46+
description=app.description,
47+
routes=app.routes,
48+
)
49+
50+
# Add security scheme for Bearer token
51+
openapi_schema["components"]["securitySchemes"] = {
52+
"BearerAuth": {
53+
"type": "http",
54+
"scheme": "bearer",
55+
"bearerFormat": "JWT",
56+
"description": "Enter your bearer token for Azure AD authentication"
57+
}
58+
}
59+
60+
# Apply security globally to all endpoints
61+
for path in openapi_schema["paths"]:
62+
for method in openapi_schema["paths"][path]:
63+
if method != "parameters":
64+
openapi_schema["paths"][path][method]["security"] = [{"BearerAuth": []}]
65+
66+
app.openapi_schema = openapi_schema
67+
return app.openapi_schema
68+
69+
app.openapi = custom_openapi
70+
3871
app.add_middleware(
3972
CORSMiddleware,
4073
allow_origins=["*"],
@@ -43,6 +76,7 @@ def custom_generate_unique_id(route: APIRoute) -> str:
4376
allow_headers=["*"],
4477
)
4578

79+
app.add_middleware(RequestContextMiddleware)
4680
app.add_middleware(AuditMiddleware)
4781
app.include_router(api_router, prefix=config.API_PREFIX)
4882

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from fastapi import Request
2+
from starlette.middleware.base import BaseHTTPMiddleware
3+
from starlette.responses import Response
4+
5+
from src.utils.request_context import set_authorization_header
6+
7+
8+
class RequestContextMiddleware(BaseHTTPMiddleware):
9+
"""Middleware to extract request headers and store them in context variables."""
10+
11+
async def dispatch(self, request: Request, call_next):
12+
# Extract Authorization header
13+
auth_header = request.headers.get("Authorization")
14+
set_authorization_header(auth_header)
15+
16+
response = await call_next(request)
17+
return response

services/chat-backend/src/services/azure.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from azure.core.exceptions import ClientAuthenticationError
55
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
66
from tlm.config.types import ModelProvider
7+
from src.utils.request_context import get_authorization_header
78

89
logger = logging.getLogger(__name__)
910

@@ -27,6 +28,12 @@ def _is_azure_model(model_provider: ModelProvider) -> bool:
2728

2829

2930
def get_azure_ad_token() -> str | None:
31+
auth_header = get_authorization_header()
32+
if auth_header:
33+
# Bearer token
34+
logger.info(f"using bearer token: {auth_header}")
35+
return auth_header.split(" ")[1]
36+
3037
try:
3138
managed_identity_client_id = os.getenv("AZURE_CLIENT_ID")
3239

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from contextvars import ContextVar
2+
from typing import Optional
3+
4+
authorization_header: ContextVar[Optional[str]] = ContextVar("authorization_header", default=None)
5+
6+
7+
def get_authorization_header() -> Optional[str]:
8+
"""Get the Authorization header value from the current request context."""
9+
return authorization_header.get()
10+
11+
12+
def set_authorization_header(value: Optional[str]) -> None:
13+
"""Set the Authorization header value in the current request context."""
14+
authorization_header.set(value)

0 commit comments

Comments
 (0)