diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index 7518f814..1a66e1ab 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -740,6 +740,129 @@ describe("useConnection", () => { }); }); + describe("MCP_PROXY_FULL_ADDRESS Configuration", () => { + beforeEach(() => { + jest.clearAllMocks(); + // Reset the mock transport objects + mockSSETransport.url = undefined; + mockSSETransport.options = undefined; + mockStreamableHTTPTransport.url = undefined; + mockStreamableHTTPTransport.options = undefined; + }); + + test("sends proxyFullAddress query parameter for stdio transport when configured", async () => { + const propsWithProxyFullAddress = { + ...defaultProps, + transportType: "stdio" as const, + command: "test-command", + args: "test-args", + env: {}, + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "https://example.com/inspector/mcp_proxy", + }, + }, + }; + + const { result } = renderHook(() => + useConnection(propsWithProxyFullAddress), + ); + + await act(async () => { + await result.current.connect(); + }); + + // Check that the URL contains the proxyFullAddress parameter + expect(mockSSETransport.url?.searchParams.get("proxyFullAddress")).toBe( + "https://example.com/inspector/mcp_proxy", + ); + }); + + test("sends proxyFullAddress query parameter for sse transport when configured", async () => { + const propsWithProxyFullAddress = { + ...defaultProps, + transportType: "sse" as const, + sseUrl: "http://localhost:8080", + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "https://example.com/inspector/mcp_proxy", + }, + }, + }; + + const { result } = renderHook(() => + useConnection(propsWithProxyFullAddress), + ); + + await act(async () => { + await result.current.connect(); + }); + + // Check that the URL contains the proxyFullAddress parameter + expect(mockSSETransport.url?.searchParams.get("proxyFullAddress")).toBe( + "https://example.com/inspector/mcp_proxy", + ); + }); + + test("does not send proxyFullAddress parameter when MCP_PROXY_FULL_ADDRESS is empty", async () => { + const propsWithEmptyProxy = { + ...defaultProps, + transportType: "stdio" as const, + command: "test-command", + args: "test-args", + env: {}, + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "", + }, + }, + }; + + const { result } = renderHook(() => useConnection(propsWithEmptyProxy)); + + await act(async () => { + await result.current.connect(); + }); + + // Check that the URL does not contain the proxyFullAddress parameter + expect( + mockSSETransport.url?.searchParams.get("proxyFullAddress"), + ).toBeNull(); + }); + + test("does not send proxyFullAddress parameter for streamable-http transport", async () => { + const propsWithStreamableHttp = { + ...defaultProps, + transportType: "streamable-http" as const, + sseUrl: "http://localhost:8080", + config: { + ...DEFAULT_INSPECTOR_CONFIG, + MCP_PROXY_FULL_ADDRESS: { + ...DEFAULT_INSPECTOR_CONFIG.MCP_PROXY_FULL_ADDRESS, + value: "https://example.com/inspector/mcp_proxy", + }, + }, + }; + + const { result } = renderHook(() => + useConnection(propsWithStreamableHttp), + ); + + await act(async () => { + await result.current.connect(); + }); + + // Check that streamable-http transport doesn't get proxyFullAddress parameter + expect( + mockStreamableHTTPTransport.url?.searchParams.get("proxyFullAddress"), + ).toBeNull(); + describe("OAuth Error Handling with Scope Discovery", () => { beforeEach(() => { jest.clearAllMocks(); @@ -875,6 +998,7 @@ describe("useConnection", () => { serverUrl: defaultProps.sseUrl, scope: undefined, }); + }); }); }); diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index e6008197..c95bed91 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -408,11 +408,20 @@ export function useConnection({ let mcpProxyServerUrl; switch (transportType) { - case "stdio": + case "stdio": { mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/stdio`); mcpProxyServerUrl.searchParams.append("command", command); mcpProxyServerUrl.searchParams.append("args", args); mcpProxyServerUrl.searchParams.append("env", JSON.stringify(env)); + + const proxyFullAddress = config.MCP_PROXY_FULL_ADDRESS + .value as string; + if (proxyFullAddress) { + mcpProxyServerUrl.searchParams.append( + "proxyFullAddress", + proxyFullAddress, + ); + } transportOptions = { authProvider: serverAuthProvider, eventSourceInit: { @@ -430,10 +439,20 @@ export function useConnection({ }, }; break; + } - case "sse": + case "sse": { mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/sse`); mcpProxyServerUrl.searchParams.append("url", sseUrl); + + const proxyFullAddressSSE = config.MCP_PROXY_FULL_ADDRESS + .value as string; + if (proxyFullAddressSSE) { + mcpProxyServerUrl.searchParams.append( + "proxyFullAddress", + proxyFullAddressSSE, + ); + } transportOptions = { eventSourceInit: { fetch: ( @@ -450,6 +469,7 @@ export function useConnection({ }, }; break; + } case "streamable-http": mcpProxyServerUrl = new URL(`${getMCPProxyAddress(config)}/mcp`); diff --git a/server/src/index.ts b/server/src/index.ts index 0a0f7bcc..bb3a97df 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -378,7 +378,6 @@ app.get( let serverTransport: Transport | undefined; try { serverTransport = await createTransport(req); - console.log("Created server transport"); } catch (error) { if (error instanceof SseError && error.code === 401) { console.error( @@ -391,11 +390,16 @@ app.get( throw error; } - const webAppTransport = new SSEServerTransport("/message", res); - console.log("Created client transport"); + const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; + const prefix = proxyFullAddress || ""; + const endpoint = `${prefix}/message`; + const webAppTransport = new SSEServerTransport(endpoint, res); webAppTransports.set(webAppTransport.sessionId, webAppTransport); + console.log("Created client transport"); + serverTransports.set(webAppTransport.sessionId, serverTransport); + console.log("Created server transport"); await webAppTransport.start(); @@ -442,7 +446,7 @@ app.get( async (req, res) => { try { console.log( - "New SSE connection request. NOTE: The sse transport is deprecated and has been replaced by StreamableHttp", + "New SSE connection request. NOTE: The SSE transport is deprecated and has been replaced by StreamableHttp", ); let serverTransport: Transport | undefined; try { @@ -469,9 +473,14 @@ app.get( } if (serverTransport) { - const webAppTransport = new SSEServerTransport("/message", res); + const proxyFullAddress = (req.query.proxyFullAddress as string) || ""; + const prefix = proxyFullAddress || ""; + const endpoint = `${prefix}/message`; + + const webAppTransport = new SSEServerTransport(endpoint, res); webAppTransports.set(webAppTransport.sessionId, webAppTransport); console.log("Created client transport"); + serverTransports.set(webAppTransport.sessionId, serverTransport!); console.log("Created server transport");