Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ConnectorPresetResponse,
ConnectorResponse,
)
from agentstack_server.api.utils import to_fastapi
from agentstack_server.configuration import ConnectorPreset
from agentstack_server.domain.models.common import PaginatedResult
from agentstack_server.domain.models.connector import Connector
Expand Down Expand Up @@ -118,8 +117,7 @@ async def mcp(
connector_service: ConnectorServiceDependency,
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(connectors={"proxy"}))],
):
response = await connector_service.mcp_proxy(connector_id=connector_id, request=request, user=user.user)
return to_fastapi(response)
return await connector_service.mcp_proxy(connector_id=connector_id, request=request, user=user.user)


@router.get("/oauth/callback")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata, get_well_known_url
from fastapi import Request, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
from kink import inject
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
Expand All @@ -33,7 +33,6 @@
)
from agentstack_server.domain.models.user import User
from agentstack_server.exceptions import EntityNotFoundError, PlatformError
from agentstack_server.service_layer.services.mcp import McpServerResponse
from agentstack_server.service_layer.unit_of_work import IUnitOfWorkFactory

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -118,7 +117,7 @@ async def connect_connector(
) from err
elif isinstance(err, httpx.RequestError):
raise PlatformError(
"Connector must be in connected or auth_required state",
"Unable to establish connection with the connector",
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
) from err
else:
Expand Down Expand Up @@ -379,10 +378,14 @@ def client_factory(headers=None, timeout=None, auth=None):
raise excgroup.exceptions[0] from excgroup
raise excgroup

async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | None = None) -> McpServerResponse:
async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | None = None):
connector = await self.read_connector(connector_id=connector_id, user=user)

forward_headers = {key: request.headers[key] for key in ["accept", "content-type"] if key in request.headers}
forward_headers = {
key: request.headers[key]
for key in ["accept", "content-type", "mcp-protocol-version", "mcp-session-id", "last-event-id"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This list of headers is getting long. For better readability and maintainability, consider defining it as a module-level constant (e.g., a frozenset for efficient lookups).

For example:

# At the top of the file
_FORWARDED_MCP_HEADERS = frozenset((
    "accept",
    "content-type",
    "mcp-protocol-version",
    "mcp-session-id",
    "last-event-id",
))

# In the method:
# ...
for key in _FORWARDED_MCP_HEADERS:
# ...

if key in request.headers
}

exit_stack = AsyncExitStack()
try:
Expand All @@ -399,32 +402,18 @@ async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User |
and connector.auth.token.token_type == "bearer"
else {}
),
content=await request.body(),
content=request.stream(),
)
)

content_type: str | None = response.headers.get("content-type")
is_stream = content_type.startswith("text/event-stream") if content_type else False
common = {
"status_code": response.status_code,
"headers": response.headers,
"media_type": content_type if is_stream else None,
}
if is_stream:

async def stream_fn():
try:
async for chunk in response.aiter_bytes():
yield chunk
finally:
await exit_stack.pop_all().aclose()

return McpServerResponse(content=None, stream=stream_fn(), **common)
else:
async def stream_fn():
try:
return McpServerResponse(content=await response.aread(), stream=None, **common)
async for chunk in response.aiter_bytes():
yield chunk
finally:
await exit_stack.pop_all().aclose()

return StreamingResponse(stream_fn(), status_code=response.status_code, headers=response.headers)
except BaseException:
await exit_stack.pop_all().aclose()
raise
Expand Down
Loading