1616from authlib .integrations .httpx_client import AsyncOAuth2Client
1717from authlib .oauth2 .rfc8414 import AuthorizationServerMetadata , get_well_known_url
1818from fastapi import Request , status
19- from fastapi .responses import HTMLResponse , RedirectResponse
19+ from fastapi .responses import HTMLResponse , RedirectResponse , StreamingResponse
2020from kink import inject
2121from mcp import ClientSession
2222from mcp .client .streamable_http import streamablehttp_client
3333)
3434from agentstack_server .domain .models .user import User
3535from agentstack_server .exceptions import EntityNotFoundError , PlatformError
36- from agentstack_server .service_layer .services .mcp import McpServerResponse
3736from agentstack_server .service_layer .unit_of_work import IUnitOfWorkFactory
3837
3938logger = logging .getLogger (__name__ )
@@ -118,7 +117,7 @@ async def connect_connector(
118117 ) from err
119118 elif isinstance (err , httpx .RequestError ):
120119 raise PlatformError (
121- "Connector must be in connected or auth_required state " ,
120+ "Unable to establish connection with the connector " ,
122121 status_code = status .HTTP_504_GATEWAY_TIMEOUT ,
123122 ) from err
124123 else :
@@ -379,10 +378,14 @@ def client_factory(headers=None, timeout=None, auth=None):
379378 raise excgroup .exceptions [0 ] from excgroup
380379 raise excgroup
381380
382- async def mcp_proxy (self , * , connector_id : UUID , request : Request , user : User | None = None ) -> McpServerResponse :
381+ async def mcp_proxy (self , * , connector_id : UUID , request : Request , user : User | None = None ):
383382 connector = await self .read_connector (connector_id = connector_id , user = user )
384383
385- forward_headers = {key : request .headers [key ] for key in ["accept" , "content-type" ] if key in request .headers }
384+ forward_headers = {
385+ key : request .headers [key ]
386+ for key in ["accept" , "content-type" , "mcp-protocol-version" , "mcp-session-id" , "last-event-id" ]
387+ if key in request .headers
388+ }
386389
387390 exit_stack = AsyncExitStack ()
388391 try :
@@ -399,32 +402,18 @@ async def mcp_proxy(self, *, connector_id: UUID, request: Request, user: User |
399402 and connector .auth .token .token_type == "bearer"
400403 else {}
401404 ),
402- content = await request .body (),
405+ content = request .stream (),
403406 )
404407 )
405408
406- content_type : str | None = response .headers .get ("content-type" )
407- is_stream = content_type .startswith ("text/event-stream" ) if content_type else False
408- common = {
409- "status_code" : response .status_code ,
410- "headers" : response .headers ,
411- "media_type" : content_type if is_stream else None ,
412- }
413- if is_stream :
414-
415- async def stream_fn ():
416- try :
417- async for chunk in response .aiter_bytes ():
418- yield chunk
419- finally :
420- await exit_stack .pop_all ().aclose ()
421-
422- return McpServerResponse (content = None , stream = stream_fn (), ** common )
423- else :
409+ async def stream_fn ():
424410 try :
425- return McpServerResponse (content = await response .aread (), stream = None , ** common )
411+ async for chunk in response .aiter_bytes ():
412+ yield chunk
426413 finally :
427414 await exit_stack .pop_all ().aclose ()
415+
416+ return StreamingResponse (stream_fn (), status_code = response .status_code , headers = response .headers )
428417 except BaseException :
429418 await exit_stack .pop_all ().aclose ()
430419 raise
0 commit comments