Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions src/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ import {
import { Transport } from "../shared/transport.js";
import { Server } from "../server/index.js";
import { InMemoryTransport } from "../inMemory.js";

import { RequestInfo } from "../server/types/types.js";

const mockRequestInfo: RequestInfo = {
headers: {
'content-type': 'application/json',
'accept': 'application/json',
},
};
/***
* Test: Initialize with Matching Protocol Version
*/
Expand All @@ -43,7 +50,7 @@ test("should initialize with matching protocol version", async () => {
},
instructions: "test instructions",
},
});
}, { requestInfo: mockRequestInfo });
}
return Promise.resolve();
}),
Expand Down Expand Up @@ -101,7 +108,7 @@ test("should initialize with supported older protocol version", async () => {
version: "1.0",
},
},
});
}, { requestInfo: mockRequestInfo });
}
return Promise.resolve();
}),
Expand Down Expand Up @@ -151,7 +158,7 @@ test("should reject unsupported protocol version", async () => {
version: "1.0",
},
},
});
}, { requestInfo: mockRequestInfo });
}
return Promise.resolve();
}),
Expand Down
14 changes: 11 additions & 3 deletions src/server/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ import {
import { Transport } from "../shared/transport.js";
import { InMemoryTransport } from "../inMemory.js";
import { Client } from "../client/index.js";
import { RequestInfo } from "./types/types.js";

const mockRequestInfo: RequestInfo = {
headers: {
'content-type': 'application/json',
'traceparent': '00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01',
},
};

test("should accept latest protocol version", async () => {
let sendPromiseResolve: (value: unknown) => void;
Expand Down Expand Up @@ -78,7 +86,7 @@ test("should accept latest protocol version", async () => {
version: "1.0",
},
},
});
}, { requestInfo: mockRequestInfo });

await expect(sendPromise).resolves.toBeUndefined();
});
Expand Down Expand Up @@ -139,7 +147,7 @@ test("should accept supported older protocol version", async () => {
version: "1.0",
},
},
});
}, { requestInfo: mockRequestInfo });

await expect(sendPromise).resolves.toBeUndefined();
});
Expand Down Expand Up @@ -199,7 +207,7 @@ test("should handle unsupported protocol version", async () => {
version: "1.0",
},
},
});
}, { requestInfo: mockRequestInfo });

await expect(sendPromise).resolves.toBeUndefined();
});
Expand Down
11 changes: 10 additions & 1 deletion src/server/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@ import {
import { ResourceTemplate } from "./mcp.js";
import { completable } from "./completable.js";
import { UriTemplate } from "../shared/uriTemplate.js";
import { RequestInfo } from "./types/types.js";
import { getDisplayName } from "../shared/metadataUtils.js";

const mockRequestInfo: RequestInfo = {
headers: {
'content-type': 'application/json',
'accept': 'application/json',
},
};

describe("McpServer", () => {
/***
* Test: Basic Server Instance
Expand Down Expand Up @@ -214,7 +222,8 @@ describe("ResourceTemplate", () => {
signal: abortController.signal,
requestId: 'not-implemented',
sendRequest: () => { throw new Error("Not implemented") },
sendNotification: () => { throw new Error("Not implemented") }
sendNotification: () => { throw new Error("Not implemented") },
requestInfo: mockRequestInfo
});
expect(result?.resources).toHaveLength(1);
expect(list).toHaveBeenCalled();
Expand Down
204 changes: 198 additions & 6 deletions src/server/sse.test.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,146 @@
import http from 'http';
import { jest } from '@jest/globals';
import { SSEServerTransport } from './sse.js';
import { McpServer } from './mcp.js';
import { createServer, type Server } from "node:http";
import { AddressInfo } from "node:net";
import { z } from 'zod';
import { CallToolResult, JSONRPCMessage } from 'src/types.js';

const createMockResponse = () => {
const res = {
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
on: jest.fn<http.ServerResponse['on']>(),
writeHead: jest.fn<http.ServerResponse['writeHead']>().mockReturnThis(),
write: jest.fn<http.ServerResponse['write']>().mockReturnThis(),
on: jest.fn<http.ServerResponse['on']>().mockReturnThis(),
end: jest.fn<http.ServerResponse['end']>().mockReturnThis(),
};
res.writeHead.mockReturnThis();
res.on.mockReturnThis();

return res as unknown as http.ServerResponse;
return res as unknown as jest.Mocked<http.ServerResponse>;
};

/**
* Helper to create and start test HTTP server with MCP setup
*/
async function createTestServerWithSse(args: {
mockRes: http.ServerResponse;
}): Promise<{
server: Server;
transport: SSEServerTransport;
mcpServer: McpServer;
baseUrl: URL;
sessionId: string
serverPort: number;
}> {
const mcpServer = new McpServer(
{ name: "test-server", version: "1.0.0" },
{ capabilities: { logging: {} } }
);

mcpServer.tool(
"greet",
"A simple greeting tool",
{ name: z.string().describe("Name to greet") },
async ({ name }): Promise<CallToolResult> => {
return { content: [{ type: "text", text: `Hello, ${name}!` }] };
}
);

const endpoint = '/messages';

const transport = new SSEServerTransport(endpoint, args.mockRes);
const sessionId = transport.sessionId;

await mcpServer.connect(transport);

const server = createServer(async (req, res) => {
try {
await transport.handlePostMessage(req, res);
} catch (error) {
console.error("Error handling request:", error);
if (!res.headersSent) res.writeHead(500).end();
}
});

const baseUrl = await new Promise<URL>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
resolve(new URL(`http://127.0.0.1:${addr.port}`));
});
});

const port = (server.address() as AddressInfo).port;

return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port };
}

async function readAllSSEEvents(response: Response): Promise<string[]> {
const reader = response.body?.getReader();
if (!reader) throw new Error('No readable stream');

const events: string[] = [];
const decoder = new TextDecoder();

try {
while (true) {
const { done, value } = await reader.read();
if (done) break;

if (value) {
events.push(decoder.decode(value));
}
}
} finally {
reader.releaseLock();
}

return events;
}

/**
* Helper to send JSON-RPC request
*/
async function sendSsePostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record<string, string>): Promise<Response> {
const headers: Record<string, string> = {
"Content-Type": "application/json",
Accept: "application/json, text/event-stream",
...extraHeaders
};

if (sessionId) {
baseUrl.searchParams.set('sessionId', sessionId);
}

return fetch(baseUrl, {
method: "POST",
headers,
body: JSON.stringify(message),
});
}

describe('SSEServerTransport', () => {

async function initializeServer(baseUrl: URL): Promise<void> {
const response = await sendSsePostRequest(baseUrl, {
jsonrpc: "2.0",
method: "initialize",
params: {
clientInfo: { name: "test-client", version: "1.0" },
protocolVersion: "2025-03-26",
capabilities: {
},
},

id: "init-1",
} as JSONRPCMessage);

expect(response.status).toBe(202);

const text = await readAllSSEEvents(response);

expect(text).toHaveLength(1);
expect(text[0]).toBe('Accepted');
}

describe('start method', () => {
it('should correctly append sessionId to a simple relative endpoint', async () => {
const mockRes = createMockResponse();
Expand Down Expand Up @@ -105,5 +231,71 @@ describe('SSEServerTransport', () => {
`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`
);
});

/***
* Test: Tool With Request Info
*/
it("should pass request info to tool callback", async () => {
const mockRes = createMockResponse();
const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes });
await initializeServer(baseUrl);

mcpServer.tool(
"test-request-info",
"A simple test tool with request info",
{ name: z.string().describe("Name to greet") },
async ({ name }, { requestInfo }): Promise<CallToolResult> => {
return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] };
}
);

const toolCallMessage: JSONRPCMessage = {
jsonrpc: "2.0",
method: "tools/call",
params: {
name: "test-request-info",
arguments: {
name: "Test User",
},
},
id: "call-1",
};

const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId);

expect(response.status).toBe(202);

expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`);

const expectedMessage = {
result: {
content: [
{
type: "text",
text: "Hello, Test User!",
},
{
type: "text",
text: JSON.stringify({
headers: {
host: `127.0.0.1:${serverPort}`,
connection: 'keep-alive',
'content-type': 'application/json',
accept: 'application/json, text/event-stream',
'accept-language': '*',
'sec-fetch-mode': 'cors',
'user-agent': 'node',
'accept-encoding': 'gzip, deflate',
'content-length': '124'
},
})
},
],
},
jsonrpc: "2.0",
id: "call-1",
};
expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`);
});
});
});
8 changes: 5 additions & 3 deletions src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import getRawBody from "raw-body";
import contentType from "content-type";
import { AuthInfo } from "./auth/types.js";
import { MessageExtraInfo, RequestInfo } from "./types/types.js";
import { URL } from 'url';

const MAXIMUM_MESSAGE_SIZE = "4mb";
Expand All @@ -19,7 +20,7 @@ export class SSEServerTransport implements Transport {
private _sessionId: string;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
onmessage?: (message: JSONRPCMessage, extra: { authInfo?: AuthInfo, requestInfo: RequestInfo }) => void;

/**
* Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
Expand Down Expand Up @@ -86,6 +87,7 @@ export class SSEServerTransport implements Transport {
throw new Error(message);
}
const authInfo: AuthInfo | undefined = req.auth;
const requestInfo: RequestInfo = { headers: req.headers };

let body: string | unknown;
try {
Expand All @@ -105,7 +107,7 @@ export class SSEServerTransport implements Transport {
}

try {
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { authInfo });
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { requestInfo, authInfo });
} catch {
res.writeHead(400).end(`Invalid message: ${body}`);
return;
Expand All @@ -117,7 +119,7 @@ export class SSEServerTransport implements Transport {
/**
* Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST.
*/
async handleMessage(message: unknown, extra?: { authInfo?: AuthInfo }): Promise<void> {
async handleMessage(message: unknown, extra: MessageExtraInfo): Promise<void> {
let parsedMessage: JSONRPCMessage;
try {
parsedMessage = JSONRPCMessageSchema.parse(message);
Expand Down
Loading