Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion client/src/lib/hooks/__tests__/useConnection.test.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { renderHook, act } from "@testing-library/react";
import { useConnection } from "../useConnection";
import { z } from "zod";
import { ClientRequest } from "@modelcontextprotocol/sdk/types.js";
import {
ClientRequest,
JSONRPCMessage,
} from "@modelcontextprotocol/sdk/types.js";
import { DEFAULT_INSPECTOR_CONFIG, CLIENT_IDENTITY } from "../../constants";
import {
SSEClientTransportOptions,
Expand Down Expand Up @@ -42,10 +45,12 @@ const mockSSETransport: {
start: jest.Mock;
url: URL | undefined;
options: SSEClientTransportOptions | undefined;
onmessage?: (message: JSONRPCMessage) => void;
} = {
start: jest.fn(),
url: undefined,
options: undefined,
onmessage: undefined,
};

const mockStreamableHTTPTransport: {
Expand Down Expand Up @@ -480,6 +485,57 @@ describe("useConnection", () => {
});
});

describe("Ref Resolution", () => {
beforeEach(() => {
jest.clearAllMocks();
});

test("resolves $ref references in requestedSchema properties before validation", async () => {
const mockProtocolOnMessage = jest.fn();

mockSSETransport.onmessage = mockProtocolOnMessage;

const { result } = renderHook(() => useConnection(defaultProps));

await act(async () => {
await result.current.connect();
});

const mockRequestWithRef: JSONRPCMessage = {
jsonrpc: "2.0",
id: 1,
method: "elicitation/create",
params: {
message: "Please provide your information",
requestedSchema: {
type: "object",
properties: {
name: {
$ref: "#/properties/nameDef",
},
nameDef: {
type: "string",
title: "Name",
},
},
},
},
};

await act(async () => {
mockSSETransport.onmessage!(mockRequestWithRef);
});

expect(mockProtocolOnMessage).toHaveBeenCalledTimes(1);

const message = mockProtocolOnMessage.mock.calls[0][0];
expect(message.params.requestedSchema.properties.name).toEqual({
type: "string",
title: "Name",
});
});
});

describe("URL Port Handling", () => {
const SSEClientTransport = jest.requireMock(
"@modelcontextprotocol/sdk/client/sse.js",
Expand Down
9 changes: 9 additions & 0 deletions client/src/lib/hooks/useConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import { getMCPServerRequestTimeout } from "@/utils/configUtils";
import { InspectorConfig } from "../configurationTypes";
import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
import { CustomHeaders } from "../types/customHeaders";
import { resolveRefsInMessage } from "@/utils/schemaUtils";

interface UseConnectionOptions {
transportType: "stdio" | "sse" | "streamable-http";
Expand Down Expand Up @@ -681,6 +682,14 @@ export function useConnection({

await client.connect(transport as Transport);

const protocolOnMessage = transport.onmessage;
if (protocolOnMessage) {
transport.onmessage = (message) => {
const resolvedMessage = resolveRefsInMessage(message);
protocolOnMessage(resolvedMessage);
};
}

setClientTransport(transport);

capabilities = client.getServerCapabilities();
Expand Down
40 changes: 39 additions & 1 deletion client/src/utils/schemaUtils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import type { JsonValue, JsonSchemaType, JsonObject } from "./jsonUtils";
import Ajv from "ajv";
import type { ValidateFunction } from "ajv";
import type { Tool } from "@modelcontextprotocol/sdk/types.js";
import type { Tool, JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import { isJSONRPCRequest } from "@modelcontextprotocol/sdk/types.js";

const ajv = new Ajv();

Expand Down Expand Up @@ -299,3 +300,40 @@ export function formatFieldLabel(key: string): string {
.replace(/_/g, " ") // Replace underscores with spaces
.replace(/^\w/, (c) => c.toUpperCase()); // Capitalize first letter
}

/**
* Resolves $ref references in a JSON-RPC message's requestedSchema
* @param message The JSON-RPC message that may contain $ref references
* @returns A new message with resolved $ref references, or the original message if no resolution is needed
*/
export function resolveRefsInMessage(message: JSONRPCMessage): JSONRPCMessage {
if (!isJSONRPCRequest(message) || !message.params?.requestedSchema) {
return message;
}

const requestedSchema = message.params.requestedSchema as JsonSchemaType;

if (!requestedSchema?.properties) {
return message;
}

const resolvedMessage = {
...message,
params: {
...message.params,
requestedSchema: {
...requestedSchema,
properties: Object.fromEntries(
Object.entries(requestedSchema.properties).map(
([key, propSchema]) => [
key,
resolveRef(propSchema, requestedSchema),
],
),
),
},
},
};

return resolvedMessage;
}