diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index b4d6705b8..ccd929650 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -1,7 +1,10 @@ import { renderHook, act } from "@testing-library/react"; import { useConnection } from "../useConnection"; import { z } from "zod"; -import { ClientRequest } from "@modelcontextprotocol/sdk/types.js"; +import { + ClientRequest, + JSONRPCMessage, +} from "@modelcontextprotocol/sdk/types.js"; import { DEFAULT_INSPECTOR_CONFIG, CLIENT_IDENTITY } from "../../constants"; import { SSEClientTransportOptions, @@ -42,10 +45,12 @@ const mockSSETransport: { start: jest.Mock; url: URL | undefined; options: SSEClientTransportOptions | undefined; + onmessage?: (message: JSONRPCMessage) => void; } = { start: jest.fn(), url: undefined, options: undefined, + onmessage: undefined, }; const mockStreamableHTTPTransport: { @@ -482,6 +487,129 @@ describe("useConnection", () => { }); }); + describe("Ref Resolution", () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test("resolves $ref references in requestedSchema properties before validation", async () => { + const mockProtocolOnMessage = jest.fn(); + + mockSSETransport.onmessage = mockProtocolOnMessage; + + const { result } = renderHook(() => useConnection(defaultProps)); + + await act(async () => { + await result.current.connect(); + }); + + const mockRequestWithRef: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "elicitation/create", + params: { + message: "Please provide your information", + requestedSchema: { + type: "object", + properties: { + source: { + type: "string", + minLength: 1, + title: "A Connectable Node", + }, + target: { + $ref: "#/properties/source", + }, + }, + }, + }, + }; + + await act(async () => { + mockSSETransport.onmessage!(mockRequestWithRef); + }); + + expect(mockProtocolOnMessage).toHaveBeenCalledTimes(1); + + const message = mockProtocolOnMessage.mock.calls[0][0]; + expect(message.params.requestedSchema.properties.target).toEqual({ + type: "string", + minLength: 1, + title: "A Connectable Node", + }); + }); + + test("resolves $ref references to $defs in requestedSchema", async () => { + const mockProtocolOnMessage = jest.fn(); + + mockSSETransport.onmessage = mockProtocolOnMessage; + + const { result } = renderHook(() => useConnection(defaultProps)); + + await act(async () => { + await result.current.connect(); + }); + + const mockRequestWithDefs: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "elicitation/create", + params: { + message: "Please provide your information", + requestedSchema: { + type: "object", + properties: { + user: { + $ref: "#/$defs/UserInput", + }, + }, + $defs: { + UserInput: { + type: "object", + properties: { + name: { + type: "string", + title: "Name", + }, + age: { + type: "integer", + title: "Age", + minimum: 0, + }, + }, + required: ["name"], + }, + }, + }, + }, + }; + + await act(async () => { + mockSSETransport.onmessage!(mockRequestWithDefs); + }); + + expect(mockProtocolOnMessage).toHaveBeenCalledTimes(1); + + const message = mockProtocolOnMessage.mock.calls[0][0]; + // The $ref should be resolved to the actual UserInput definition + expect(message.params.requestedSchema.properties.user).toEqual({ + type: "object", + properties: { + name: { + type: "string", + title: "Name", + }, + age: { + type: "integer", + title: "Age", + minimum: 0, + }, + }, + required: ["name"], + }); + }); + }); + describe("URL Port Handling", () => { const SSEClientTransport = jest.requireMock( "@modelcontextprotocol/sdk/client/sse.js", diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index bd15e080a..e95cad97d 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -59,6 +59,7 @@ import { getMCPServerRequestTimeout } from "@/utils/configUtils"; import { InspectorConfig } from "../configurationTypes"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { CustomHeaders } from "../types/customHeaders"; +import { resolveRefsInMessage } from "@/utils/schemaUtils"; interface UseConnectionOptions { transportType: "stdio" | "sse" | "streamable-http"; @@ -691,6 +692,14 @@ export function useConnection({ await client.connect(transport as Transport); + const protocolOnMessage = transport.onmessage; + if (protocolOnMessage) { + transport.onmessage = (message) => { + const resolvedMessage = resolveRefsInMessage(message); + protocolOnMessage(resolvedMessage); + }; + } + setClientTransport(transport); capabilities = client.getServerCapabilities(); diff --git a/client/src/utils/schemaUtils.ts b/client/src/utils/schemaUtils.ts index 365870216..42d77d439 100644 --- a/client/src/utils/schemaUtils.ts +++ b/client/src/utils/schemaUtils.ts @@ -1,7 +1,8 @@ import type { JsonValue, JsonSchemaType, JsonObject } from "./jsonUtils"; import Ajv from "ajv"; import type { ValidateFunction } from "ajv"; -import type { Tool } from "@modelcontextprotocol/sdk/types.js"; +import type { Tool, JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { isJSONRPCRequest } from "@modelcontextprotocol/sdk/types.js"; const ajv = new Ajv(); @@ -299,3 +300,41 @@ export function formatFieldLabel(key: string): string { .replace(/_/g, " ") // Replace underscores with spaces .replace(/^\w/, (c) => c.toUpperCase()); // Capitalize first letter } + +/** + * Resolves `$ref` references in a JSON-RPC "elicitation/create" message's `requestedSchema` field + * @param message The JSON-RPC message that may contain $ref references + * @returns A new message with resolved $ref references, or the original message if no resolution is needed + */ +export function resolveRefsInMessage(message: JSONRPCMessage): JSONRPCMessage { + if (!isJSONRPCRequest(message) || !message.params?.requestedSchema) { + return message; + } + + const requestedSchema = message.params.requestedSchema as JsonSchemaType; + + if (!requestedSchema?.properties) { + return message; + } + + const resolvedMessage = { + ...message, + params: { + ...message.params, + requestedSchema: { + ...requestedSchema, + properties: Object.fromEntries( + Object.entries(requestedSchema.properties).map( + ([key, propSchema]) => { + const resolved = resolveRef(propSchema, requestedSchema); + const normalized = normalizeUnionType(resolved); + return [key, normalized]; + }, + ), + ), + }, + }, + }; + + return resolvedMessage; +}