Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class AuthorizationCodeFlow(BaseModel):
code_verifier: str
redirect_uri: str
client_redirect_uri: AnyUrl | None
resource: str


AuthFlow = Annotated[AuthorizationCodeFlow, Field(discriminator="type")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from kink import inject
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl, BaseModel
from pydantic import AnyUrl, BaseModel, Field

from agentstack_server.configuration import Configuration, ConnectorPreset
from agentstack_server.domain.models.common import Metadata
Expand All @@ -34,6 +34,7 @@
from agentstack_server.domain.models.user import User
from agentstack_server.exceptions import EntityNotFoundError, PlatformError
from agentstack_server.service_layer.unit_of_work import IUnitOfWorkFactory
from agentstack_server.utils.oauth import parse_bearer_mcp_www_authenticate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,7 +108,10 @@ async def connect_connector(
if isinstance(err, httpx.HTTPStatusError):
if err.response.status_code == status.HTTP_401_UNAUTHORIZED:
await self._bootstrap_auth(
connector=connector, callback_url=callback_uri, redirect_url=redirect_url
connector=connector,
callback_url=callback_uri,
redirect_url=redirect_url,
www_authenticate=err.response.headers.get("www-authenticate"),
)
connector.state = ConnectorState.auth_required
else:
Expand Down Expand Up @@ -175,9 +179,10 @@ async def oauth_callback(self, *, callback_url: str, state: str, error: str | No
)

async with self._create_oauth_client(connector=connector) as client:
auth_metadata = await self._discover_auth_metadata(connector=connector)
if not auth_metadata:
raise RuntimeError("Authorization server no longer contains necessary metadata")
metadata = await self._discover_connector_metadata(connector=connector)
if not metadata:
raise RuntimeError("Connector no longer contains necessary metadata")
auth_metadata, _ = metadata
token_endpoint = auth_metadata.get("token_endpoint")
if not token_endpoint:
raise RuntimeError("Authorization server has no token endpoint in metadata")
Expand All @@ -186,6 +191,7 @@ async def oauth_callback(self, *, callback_url: str, state: str, error: str | No
authorization_response=callback_url,
code_verifier=connector.auth.flow.code_verifier,
redirect_uri=connector.auth.flow.redirect_uri,
resource=connector.auth.flow.resource,
)
connector.auth.token = Token.model_validate(token)
connector.auth.token_endpoint = AnyUrl(str(token_endpoint))
Expand Down Expand Up @@ -264,10 +270,30 @@ def _find_preset(self, *, url: AnyUrl) -> ConnectorPreset | None:
return preset
return None

async def _bootstrap_auth(self, *, connector: Connector, callback_url: str, redirect_url: AnyUrl | None) -> None:
auth_metadata = await self._discover_auth_metadata(connector=connector)
if not auth_metadata:
raise RuntimeError("Not authorization server found for the connector")
async def _bootstrap_auth(
self,
*,
connector: Connector,
callback_url: str,
redirect_url: AnyUrl | None,
www_authenticate: str | None = None,
) -> None:
resource_metadata_url: str | None = None
scope: str | None = None
if www_authenticate:
try:
parsed = parse_bearer_mcp_www_authenticate(www_authenticate)
resource_metadata_url = parsed.get("resource_metadata")
scope = parsed.get("scope")
except Exception:
logger.warning(f"Failed to parse www-authenticate header: {www_authenticate}", exc_info=True)

metadata = await self._discover_connector_metadata(
connector=connector, resource_metadata_url=resource_metadata_url
)
if not metadata:
raise RuntimeError("No metadata found for the connector")
auth_metadata, resource_metadata = metadata

if not connector.auth:
connector.auth = Authorization()
Expand All @@ -279,14 +305,20 @@ async def _bootstrap_auth(self, *, connector: Connector, callback_url: str, redi

async with self._create_oauth_client(connector=connector) as client:
uri, state = client.create_authorization_url(
auth_metadata.get("authorization_endpoint"), code_verifier=code_verifier, redirect_uri=callback_url
auth_metadata.get("authorization_endpoint"),
code_verifier=code_verifier,
redirect_uri=callback_url,
resource=resource_metadata.resource,
scope=scope
or (" ".join(resource_metadata.scopes_supported) if resource_metadata.scopes_supported else None),
)
connector.auth.flow = AuthorizationCodeFlow(
authorization_endpoint=uri,
state=state,
code_verifier=code_verifier,
redirect_uri=callback_url,
client_redirect_uri=redirect_url,
resource=resource_metadata.resource,
)

async def _revoke_auth_token(self, *, connector: Connector) -> None:
Expand All @@ -296,9 +328,10 @@ async def _revoke_auth_token(self, *, connector: Connector) -> None:
if connector.auth.token:
try:
async with self._create_oauth_client(connector=connector) as client:
auth_metadata = await self._discover_auth_metadata(connector=connector)
if not auth_metadata:
metadata = await self._discover_connector_metadata(connector=connector)
if not metadata:
raise RuntimeError("Authorization server no longer contains necessary metadata")
auth_metadata, _ = metadata
revoke_endpoint = auth_metadata.get("revocation_endpoint")
if not isinstance(revoke_endpoint, str):
raise RuntimeError("Authorization server does not support token revocation")
Expand Down Expand Up @@ -346,12 +379,16 @@ async def update_token(token, refresh_token=None, access_token=None):
token_endpoint=str(connector.auth.token_endpoint),
)

async def _discover_auth_metadata(self, *, connector: Connector) -> AuthorizationServerMetadata | None:
resource_metadata = await _discover_resource_metadata(str(connector.url))
async def _discover_connector_metadata(
self, *, connector: Connector, resource_metadata_url: str | None = None
) -> tuple[AuthorizationServerMetadata, _ResourceServerMetadata] | None:
resource_metadata = await _discover_resource_metadata(resource_metadata_url or str(connector.url))
if not resource_metadata or not resource_metadata.authorization_servers:
return None
auth_metadata = await _discover_auth_metadata(resource_metadata.authorization_servers[0])
return auth_metadata
if not auth_metadata:
return None
return auth_metadata, resource_metadata

async def _ensure_oauth_client_registered(self, *, connector: Connector, redirect_uri: str) -> Connector:
if not connector.auth:
Expand Down Expand Up @@ -531,7 +568,9 @@ def _render_failure(error: str, error_description: str | None):


class _ResourceServerMetadata(BaseModel):
authorization_servers: list[str]
resource: str
authorization_servers: list[str] = Field(default_factory=list)
scopes_supported: list[str] = Field(default_factory=list)


class _ClientRegistrationResponse(BaseModel):
Expand Down
39 changes: 39 additions & 0 deletions apps/agentstack-server/src/agentstack_server/utils/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0


import re


def parse_bearer_mcp_www_authenticate(header: str) -> dict[str, str]:
"""
Parses a WWW-Authenticate header like:
Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read"

Returns a dict with the scheme and all parameters.
"""
# Normalize: remove extra spaces/tabs/newlines
header = header.strip()

# Match the scheme (Bearer) and the rest
match = re.match(r"^(\w+)\s+(.*)$", header, re.IGNORECASE)
if not match:
raise ValueError("Invalid WWW-Authenticate header")

scheme = match.group(1).strip()
params_part = match.group(2)

if scheme.lower() != "bearer":
raise ValueError("Not a bearer scheme")

# Extract all key="value" pairs (values are quoted)
params = {}
for k, v in re.findall(r'(\w+(?:_\w+)?)="([^"]*)"', params_part):
params[k] = v

# Also catch any unquoted values that might slip through (rare)
for k, v in re.findall(r"(\w+(?:_\w+)?)=([^,\s]+)", params_part):
if k not in params:
params[k] = v

return params
Loading