diff --git a/apps/agentstack-server/src/agentstack_server/api/routes/connectors.py b/apps/agentstack-server/src/agentstack_server/api/routes/connectors.py index 9ae60bd64..d9079fd29 100644 --- a/apps/agentstack-server/src/agentstack_server/api/routes/connectors.py +++ b/apps/agentstack-server/src/agentstack_server/api/routes/connectors.py @@ -18,7 +18,6 @@ ConnectorPresetResponse, ConnectorResponse, ) -from agentstack_server.api.utils import to_fastapi from agentstack_server.configuration import ConnectorPreset from agentstack_server.domain.models.common import PaginatedResult from agentstack_server.domain.models.connector import Connector @@ -118,8 +117,7 @@ async def mcp( connector_service: ConnectorServiceDependency, user: Annotated[AuthorizedUser, Depends(RequiresPermissions(connectors={"proxy"}))], ): - response = await connector_service.mcp_proxy(connector_id=connector_id, request=request, user=user.user) - return to_fastapi(response) + return await connector_service.mcp_proxy(connector_id=connector_id, request=request, user=user.user) @router.get("/oauth/callback") diff --git a/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py b/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py index 766d9ca6a..e5faca55d 100644 --- a/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py +++ b/apps/agentstack-server/src/agentstack_server/service_layer/services/connector.py @@ -16,7 +16,7 @@ from authlib.integrations.httpx_client import AsyncOAuth2Client from authlib.oauth2.rfc8414 import AuthorizationServerMetadata, get_well_known_url from fastapi import Request, status -from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse from kink import inject from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client @@ -33,7 +33,6 @@ ) from agentstack_server.domain.models.user import User from agentstack_server.exceptions import EntityNotFoundError, PlatformError -from agentstack_server.service_layer.services.mcp import McpServerResponse from agentstack_server.service_layer.unit_of_work import IUnitOfWorkFactory logger = logging.getLogger(__name__) @@ -118,7 +117,7 @@ async def connect_connector( ) from err elif isinstance(err, httpx.RequestError): raise PlatformError( - "Connector must be in connected or auth_required state", + "Unable to establish connection with the connector", status_code=status.HTTP_504_GATEWAY_TIMEOUT, ) from err else: @@ -379,10 +378,14 @@ def client_factory(headers=None, timeout=None, auth=None): raise excgroup.exceptions[0] from excgroup raise excgroup - async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | None = None) -> McpServerResponse: + async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | None = None): connector = await self.read_connector(connector_id=connector_id, user=user) - forward_headers = {key: request.headers[key] for key in ["accept", "content-type"] if key in request.headers} + forward_headers = { + key: request.headers[key] + for key in ["accept", "content-type", "mcp-protocol-version", "mcp-session-id", "last-event-id"] + if key in request.headers + } exit_stack = AsyncExitStack() try: @@ -399,32 +402,18 @@ async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | and connector.auth.token.token_type == "bearer" else {} ), - content=await request.body(), + content=request.stream(), ) ) - content_type: str | None = response.headers.get("content-type") - is_stream = content_type.startswith("text/event-stream") if content_type else False - common = { - "status_code": response.status_code, - "headers": response.headers, - "media_type": content_type if is_stream else None, - } - if is_stream: - - async def stream_fn(): - try: - async for chunk in response.aiter_bytes(): - yield chunk - finally: - await exit_stack.pop_all().aclose() - - return McpServerResponse(content=None, stream=stream_fn(), **common) - else: + async def stream_fn(): try: - return McpServerResponse(content=await response.aread(), stream=None, **common) + async for chunk in response.aiter_bytes(): + yield chunk finally: await exit_stack.pop_all().aclose() + + return StreamingResponse(stream_fn(), status_code=response.status_code, headers=response.headers) except BaseException: await exit_stack.pop_all().aclose() raise