Skip to content

Commit b7c8b3e

Browse files
authored
fix(server): ensure resource protected metadata RFC compliance (#1200)
Signed-off-by: Tomas Pilar <thomas7pilar@gmail.com>
1 parent 13af6f1 commit b7c8b3e

File tree

5 files changed

+33
-20
lines changed

5 files changed

+33
-20
lines changed

apps/beeai-cli/src/beeai_cli/commands/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def server_login(server: typing.Annotated[str | None, typer.Argument()] =
126126
else:
127127
console.info("No authentication tokens found for this server. Proceeding to log in.")
128128
async with httpx.AsyncClient(verify=await get_verify_option(server)) as client:
129-
resp = await client.get(f"{server}/api/v1/.well-known/oauth-protected-resource")
129+
resp = await client.get(f"{server}/.well-known/oauth-protected-resource")
130130
if resp.is_error:
131131
console.error("This server does not appear to run a compatible version of BeeAI Platform.")
132132
sys.exit(1)

apps/beeai-server/src/beeai_server/api/dependencies.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Annotated
66
from uuid import UUID
77

8-
from fastapi import Depends, HTTPException, Path, Query, Security, status
8+
from fastapi import Depends, HTTPException, Path, Query, Request, Security, status
99
from fastapi.security import APIKeyCookie, HTTPAuthorizationCredentials, HTTPBasic, HTTPBasicCredentials, HTTPBearer
1010
from jwt import PyJWTError
1111
from kink import di
@@ -56,6 +56,7 @@ async def authenticate_oauth_user(
5656
cookie_auth: str | None,
5757
user_service: UserServiceDependency,
5858
configuration: ConfigurationDependency,
59+
request: Request,
5960
) -> AuthorizedUser:
6061
"""
6162
Authenticate using an OIDC/OAuth2 JWT bearer token with JWKS.
@@ -69,8 +70,9 @@ async def authenticate_oauth_user(
6970
detail=f"Invalid Authorization header: {e}",
7071
) from e
7172

73+
expected_audience = str(request.url.replace(path="/"))
7274
claims, issuer = await decode_oauth_jwt_or_introspect(
73-
token=token, jwks_dict=di["JWKS_CACHE"], aud="beeai-server", configuration=configuration
75+
token=token, jwks_dict=di["JWKS_CACHE"], aud=expected_audience, configuration=configuration
7476
)
7577
if not claims:
7678
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token")
@@ -112,6 +114,7 @@ async def authorized_user(
112114
basic_auth: Annotated[HTTPBasicCredentials | None, Depends(HTTPBasic(auto_error=False))],
113115
bearer_auth: Annotated[HTTPAuthorizationCredentials | None, Depends(HTTPBearer(auto_error=False))],
114116
cookie_auth: Annotated[str | None, Security(api_key_cookie)],
117+
request: Request,
115118
) -> AuthorizedUser:
116119
if bearer_auth:
117120
# Check Bearer token first - locally this allows for "checking permissions" for development purposes
@@ -128,12 +131,12 @@ async def authorized_user(
128131
return token
129132
except PyJWTError:
130133
if configuration.auth.oidc.enabled:
131-
return await authenticate_oauth_user(bearer_auth, cookie_auth, user_service, configuration)
134+
return await authenticate_oauth_user(bearer_auth, cookie_auth, user_service, configuration, request)
132135
# TODO: update agents
133136
logger.warning("Bearer token is invalid, agent is not probably not using llm extension correctly")
134137

135138
if configuration.auth.oidc.enabled and cookie_auth:
136-
return await authenticate_oauth_user(bearer_auth, cookie_auth, user_service, configuration)
139+
return await authenticate_oauth_user(bearer_auth, cookie_auth, user_service, configuration, request)
137140

138141
if configuration.auth.basic.enabled:
139142
if basic_auth and basic_auth.password == configuration.auth.basic.admin_password.get_secret_value():

apps/beeai-server/src/beeai_server/api/routes/auth.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import logging
44

5-
from fastapi import APIRouter
5+
from fastapi import APIRouter, Request
66

77
from beeai_server.api.dependencies import AuthServiceDependency
88

99
logger = logging.getLogger(__name__)
1010

11-
router = APIRouter()
11+
well_known_router = APIRouter()
1212

1313

14-
@router.get("/.well-known/oauth-protected-resource")
15-
def protected_resource_metadata(auth_servide: AuthServiceDependency):
16-
return auth_servide.protected_resource_metadata()
14+
@well_known_router.get("/oauth-protected-resource/{resource:path}")
15+
def protected_resource_metadata(
16+
request: Request,
17+
auth_servide: AuthServiceDependency,
18+
resource: str = "",
19+
):
20+
return auth_servide.protected_resource_metadata(resource=str(request.url.replace(path=resource)))

apps/beeai-server/src/beeai_server/application.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from kink import Container, di, inject
1313
from opentelemetry.metrics import CallbackOptions, Observation, get_meter
1414
from starlette.requests import Request
15-
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
15+
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_500_INTERNAL_SERVER_ERROR
1616

1717
from beeai_server.api.routes.a2a import router as a2a_router
18-
from beeai_server.api.routes.auth import router as auth_router
18+
from beeai_server.api.routes.auth import well_known_router as auth_well_known_router
1919
from beeai_server.api.routes.configurations import router as configuration_router
2020
from beeai_server.api.routes.contexts import router as contexts_router
2121
from beeai_server.api.routes.files import router as files_router
@@ -72,12 +72,18 @@ async def custom_http_exception_handler(request: Request, exc):
7272
exception = exc
7373
case _:
7474
exception = HTTPException(HTTP_500_INTERNAL_SERVER_ERROR, detail=repr(extract_messages(exc)))
75+
76+
if isinstance(exception, HTTPException) and exc.status_code == HTTP_401_UNAUTHORIZED:
77+
exception.headers = exception.headers or {}
78+
exception.headers |= {
79+
"WWW-Authenticate": f'Bearer resource_metadata="{request.url.replace(path="/.well-known/oauth-protected-resource")}"' # We don't define multiple resource domains at the moment
80+
}
81+
7582
return await http_exception_handler(request, exception)
7683

7784

7885
def mount_routes(app: FastAPI):
7986
server_router = APIRouter()
80-
server_router.include_router(auth_router, prefix="", tags=["auth"])
8187
server_router.include_router(a2a_router, prefix="/a2a")
8288
server_router.include_router(mcp_router, prefix="/mcp")
8389
server_router.include_router(provider_router, prefix="/providers", tags=["providers"])
@@ -89,7 +95,11 @@ def mount_routes(app: FastAPI):
8995
server_router.include_router(vector_stores_router, prefix="/vector_stores", tags=["vector_stores"])
9096
server_router.include_router(user_feedback_router, prefix="/user_feedback", tags=["user_feedback"])
9197

98+
well_known_router = APIRouter()
99+
well_known_router.include_router(auth_well_known_router, prefix="")
100+
92101
app.include_router(server_router, prefix="/api/v1", tags=["provider"])
102+
app.include_router(well_known_router, prefix="/.well-known", tags=["well-known"])
93103

94104
@app.get("/healthcheck")
95105
async def healthcheck():

apps/beeai-server/src/beeai_server/service_layer/services/auth.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,9 @@ class AuthService:
1515
def __init__(self, configuration: Configuration):
1616
self._config = configuration
1717

18-
def protected_resource_metadata(self) -> dict:
19-
resource_id = f"http://localhost:{self._config.port}" # TODO
20-
providers = self._config.auth.oidc.providers
21-
authorization_server = [str(p.issuer) for p in providers if p.issuer is not None]
22-
18+
def protected_resource_metadata(self, *, resource: str) -> dict:
2319
return {
24-
"resource_id": resource_id,
25-
"authorization_servers": authorization_server,
20+
"resource": resource,
21+
"authorization_servers": [str(p.issuer) for p in self._config.auth.oidc.providers if p.issuer is not None],
2622
"scopes_supported": list(self._config.auth.oidc.scope),
2723
}

0 commit comments

Comments
 (0)