Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def connect_connector(
) from err
elif isinstance(err, httpx.RequestError):
raise PlatformError(
"Connector must be in connected or auth_required state",
"Unable to establish connection with the connector",
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
) from err
else:
Expand Down Expand Up @@ -382,7 +382,11 @@ def client_factory(headers=None, timeout=None, auth=None):
async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User | None = None) -> McpServerResponse:
connector = await self.read_connector(connector_id=connector_id, user=user)

forward_headers = {key: request.headers[key] for key in ["accept", "content-type"] if key in request.headers}
forward_headers = {
key: request.headers[key]
for key in ["accept", "content-type", "mcp-protocol-version", "mcp-session-id", "last-event-id"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This list of headers is getting long. For better readability and maintainability, consider defining it as a module-level constant (e.g., a frozenset for efficient lookups).

For example:

# At the top of the file
_FORWARDED_MCP_HEADERS = frozenset((
    "accept",
    "content-type",
    "mcp-protocol-version",
    "mcp-session-id",
    "last-event-id",
))

# In the method:
# ...
for key in _FORWARDED_MCP_HEADERS:
# ...

if key in request.headers
}

exit_stack = AsyncExitStack()
try:
Expand All @@ -404,11 +408,13 @@ async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User |
)

content_type: str | None = response.headers.get("content-type")
session_id: str | None = response.headers.get("mcp-session-id")
is_stream = content_type.startswith("text/event-stream") if content_type else False

common = {
"status_code": response.status_code,
"headers": response.headers,
"media_type": content_type if is_stream else None,
"headers": {"mcp-session-id": session_id} if session_id else {},
"media_type": content_type,
}
if is_stream:

Expand Down
Loading