diff --git a/apps/agentstack-server/src/agentstack_server/domain/models/connector.py b/apps/agentstack-server/src/agentstack_server/domain/models/connector.py index ade3b8a64..cbd736b41 100644 --- a/apps/agentstack-server/src/agentstack_server/domain/models/connector.py +++ b/apps/agentstack-server/src/agentstack_server/domain/models/connector.py @@ -18,6 +18,7 @@ class AuthorizationCodeFlow(BaseModel): code_verifier: str redirect_uri: str client_redirect_uri: AnyUrl | None + resource: str AuthFlow = Annotated[AuthorizationCodeFlow, Field(discriminator="type")] diff --git a/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py b/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py index c2390bab6..98549551d 100644 --- a/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py +++ b/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py @@ -20,7 +20,7 @@ from kink import inject from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import AnyUrl, BaseModel +from pydantic import AnyUrl, BaseModel, Field from agentstack_server.configuration import Configuration, ConnectorPreset from agentstack_server.domain.models.common import Metadata @@ -34,6 +34,7 @@ from agentstack_server.domain.models.user import User from agentstack_server.exceptions import EntityNotFoundError, PlatformError from agentstack_server.service_layer.unit_of_work import IUnitOfWorkFactory +from agentstack_server.utils.oauth import parse_bearer_mcp_www_authenticate logger = logging.getLogger(__name__) @@ -107,7 +108,10 @@ async def connect_connector( if isinstance(err, httpx.HTTPStatusError): if err.response.status_code == status.HTTP_401_UNAUTHORIZED: await self._bootstrap_auth( - connector=connector, callback_url=callback_uri, redirect_url=redirect_url + connector=connector, + callback_url=callback_uri, + redirect_url=redirect_url, + www_authenticate=err.response.headers.get("www-authenticate"), ) connector.state = ConnectorState.auth_required else: @@ -175,9 +179,10 @@ async def oauth_callback(self, *, callback_url: str, state: str, error: str | No ) async with self._create_oauth_client(connector=connector) as client: - auth_metadata = await self._discover_auth_metadata(connector=connector) - if not auth_metadata: - raise RuntimeError("Authorization server no longer contains necessary metadata") + metadata = await self._discover_connector_metadata(connector=connector) + if not metadata: + raise RuntimeError("Connector no longer contains necessary metadata") + auth_metadata, _ = metadata token_endpoint = auth_metadata.get("token_endpoint") if not token_endpoint: raise RuntimeError("Authorization server has no token endpoint in metadata") @@ -186,6 +191,7 @@ async def oauth_callback(self, *, callback_url: str, state: str, error: str | No authorization_response=callback_url, code_verifier=connector.auth.flow.code_verifier, redirect_uri=connector.auth.flow.redirect_uri, + resource=connector.auth.flow.resource, ) connector.auth.token = Token.model_validate(token) connector.auth.token_endpoint = AnyUrl(str(token_endpoint)) @@ -264,10 +270,30 @@ def _find_preset(self, *, url: AnyUrl) -> ConnectorPreset | None: return preset return None - async def _bootstrap_auth(self, *, connector: Connector, callback_url: str, redirect_url: AnyUrl | None) -> None: - auth_metadata = await self._discover_auth_metadata(connector=connector) - if not auth_metadata: - raise RuntimeError("Not authorization server found for the connector") + async def _bootstrap_auth( + self, + *, + connector: Connector, + callback_url: str, + redirect_url: AnyUrl | None, + www_authenticate: str | None = None, + ) -> None: + resource_metadata_url: str | None = None + scope: str | None = None + if www_authenticate: + try: + parsed = parse_bearer_mcp_www_authenticate(www_authenticate) + resource_metadata_url = parsed.get("resource_metadata") + scope = parsed.get("scope") + except Exception: + logger.warning(f"Failed to parse www-authenticate header: {www_authenticate}", exc_info=True) + + metadata = await self._discover_connector_metadata( + connector=connector, resource_metadata_url=resource_metadata_url + ) + if not metadata: + raise RuntimeError("No metadata found for the connector") + auth_metadata, resource_metadata = metadata if not connector.auth: connector.auth = Authorization() @@ -279,7 +305,12 @@ async def _bootstrap_auth(self, *, connector: Connector, callback_url: str, redi async with self._create_oauth_client(connector=connector) as client: uri, state = client.create_authorization_url( - auth_metadata.get("authorization_endpoint"), code_verifier=code_verifier, redirect_uri=callback_url + auth_metadata.get("authorization_endpoint"), + code_verifier=code_verifier, + redirect_uri=callback_url, + resource=resource_metadata.resource, + scope=scope + or (" ".join(resource_metadata.scopes_supported) if resource_metadata.scopes_supported else None), ) connector.auth.flow = AuthorizationCodeFlow( authorization_endpoint=uri, @@ -287,6 +318,7 @@ async def _bootstrap_auth(self, *, connector: Connector, callback_url: str, redi code_verifier=code_verifier, redirect_uri=callback_url, client_redirect_uri=redirect_url, + resource=resource_metadata.resource, ) async def _revoke_auth_token(self, *, connector: Connector) -> None: @@ -296,9 +328,10 @@ async def _revoke_auth_token(self, *, connector: Connector) -> None: if connector.auth.token: try: async with self._create_oauth_client(connector=connector) as client: - auth_metadata = await self._discover_auth_metadata(connector=connector) - if not auth_metadata: + metadata = await self._discover_connector_metadata(connector=connector) + if not metadata: raise RuntimeError("Authorization server no longer contains necessary metadata") + auth_metadata, _ = metadata revoke_endpoint = auth_metadata.get("revocation_endpoint") if not isinstance(revoke_endpoint, str): raise RuntimeError("Authorization server does not support token revocation") @@ -346,12 +379,16 @@ async def update_token(token, refresh_token=None, access_token=None): token_endpoint=str(connector.auth.token_endpoint), ) - async def _discover_auth_metadata(self, *, connector: Connector) -> AuthorizationServerMetadata | None: - resource_metadata = await _discover_resource_metadata(str(connector.url)) + async def _discover_connector_metadata( + self, *, connector: Connector, resource_metadata_url: str | None = None + ) -> tuple[AuthorizationServerMetadata, _ResourceServerMetadata] | None: + resource_metadata = await _discover_resource_metadata(resource_metadata_url or str(connector.url)) if not resource_metadata or not resource_metadata.authorization_servers: return None auth_metadata = await _discover_auth_metadata(resource_metadata.authorization_servers[0]) - return auth_metadata + if not auth_metadata: + return None + return auth_metadata, resource_metadata async def _ensure_oauth_client_registered(self, *, connector: Connector, redirect_uri: str) -> Connector: if not connector.auth: @@ -531,7 +568,9 @@ def _render_failure(error: str, error_description: str | None): class _ResourceServerMetadata(BaseModel): - authorization_servers: list[str] + resource: str + authorization_servers: list[str] = Field(default_factory=list) + scopes_supported: list[str] = Field(default_factory=list) class _ClientRegistrationResponse(BaseModel): diff --git a/apps/agentstack-server/src/agentstack_server/utils/oauth.py b/apps/agentstack-server/src/agentstack_server/utils/oauth.py new file mode 100644 index 000000000..c6c0e77b5 --- /dev/null +++ b/apps/agentstack-server/src/agentstack_server/utils/oauth.py @@ -0,0 +1,39 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + + +import re + + +def parse_bearer_mcp_www_authenticate(header: str) -> dict[str, str]: + """ + Parses a WWW-Authenticate header like: + Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read" + + Returns a dict with the scheme and all parameters. + """ + # Normalize: remove extra spaces/tabs/newlines + header = header.strip() + + # Match the scheme (Bearer) and the rest + match = re.match(r"^(\w+)\s+(.*)$", header, re.IGNORECASE) + if not match: + raise ValueError("Invalid WWW-Authenticate header") + + scheme = match.group(1).strip() + params_part = match.group(2) + + if scheme.lower() != "bearer": + raise ValueError("Not a bearer scheme") + + # Extract all key="value" pairs (values are quoted) + params = {} + for k, v in re.findall(r'(\w+(?:_\w+)?)="([^"]*)"', params_part): + params[k] = v + + # Also catch any unquoted values that might slip through (rare) + for k, v in re.findall(r"(\w+(?:_\w+)?)=([^,\s]+)", params_part): + if k not in params: + params[k] = v + + return params