diff --git a/src/client/index.ts b/src/client/index.ts index 98618a171..a8fbdcee8 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -165,6 +165,10 @@ export class Client< this._serverCapabilities = result.capabilities; this._serverVersion = result.serverInfo; + // HTTP transports must set the protocol version in each header after initialization. + if (transport.setProtocolVersion) { + transport.setProtocolVersion(result.protocolVersion); + } this._instructions = result.instructions; diff --git a/src/client/sse.ts b/src/client/sse.ts index 7939e8cb5..5aa99abb4 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -62,6 +62,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _protocolVersion?: string; onclose?: () => void; onerror?: (error: Error) => void; @@ -99,13 +100,18 @@ export class SSEClientTransport implements Transport { } private async _commonHeaders(): Promise { - const headers: HeadersInit = { ...this._requestInit?.headers }; + const headers = { + ...this._requestInit?.headers, + } as HeadersInit & Record; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { - (headers as Record)["Authorization"] = `Bearer ${tokens.access_token}`; + headers["Authorization"] = `Bearer ${tokens.access_token}`; } } + if (this._protocolVersion) { + headers["mcp-protocol-version"] = this._protocolVersion; + } return headers; } @@ -214,7 +220,7 @@ export class SSEClientTransport implements Transport { try { const commonHeaders = await this._commonHeaders(); - const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }); + const headers = new Headers(commonHeaders); headers.set("content-type", "application/json"); const init = { ...this._requestInit, @@ -249,4 +255,8 @@ export class SSEClientTransport implements Transport { throw error; } } + + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } } diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 1bcfbb2d1..4117bb1b4 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -124,6 +124,7 @@ export class StreamableHTTPClientTransport implements Transport { private _authProvider?: OAuthClientProvider; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; + private _protocolVersion?: string; onclose?: () => void; onerror?: (error: Error) => void; @@ -162,7 +163,7 @@ export class StreamableHTTPClientTransport implements Transport { } private async _commonHeaders(): Promise { - const headers: HeadersInit = {}; + const headers: HeadersInit & Record = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { @@ -173,6 +174,9 @@ export class StreamableHTTPClientTransport implements Transport { if (this._sessionId) { headers["mcp-session-id"] = this._sessionId; } + if (this._protocolVersion) { + headers["mcp-protocol-version"] = this._protocolVersion; + } return new Headers( { ...headers, ...this._requestInit?.headers } @@ -516,4 +520,11 @@ export class StreamableHTTPClientTransport implements Transport { throw error; } } + + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } + get protocolVersion(): string | undefined { + return this._protocolVersion; + } } diff --git a/src/integration-tests/stateManagementStreamableHttp.test.ts b/src/integration-tests/stateManagementStreamableHttp.test.ts index b7ff17e68..d12a4f993 100644 --- a/src/integration-tests/stateManagementStreamableHttp.test.ts +++ b/src/integration-tests/stateManagementStreamableHttp.test.ts @@ -5,7 +5,7 @@ import { Client } from '../client/index.js'; import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; -import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema } from '../types.js'; +import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from '../types.js'; import { z } from 'zod'; describe('Streamable HTTP Transport Session Management', () => { @@ -211,6 +211,27 @@ describe('Streamable HTTP Transport Session Management', () => { // Clean up await transport.close(); }); + + it('should set protocol version after connecting', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + + // Verify protocol version is not set before connecting + expect(transport.protocolVersion).toBeUndefined(); + + await client.connect(transport); + + // Verify protocol version is set after connecting + expect(transport.protocolVersion).toBe(DEFAULT_NEGOTIATED_PROTOCOL_VERSION); + + // Clean up + await transport.close(); + }); }); describe('Stateful Mode', () => { diff --git a/src/server/index.ts b/src/server/index.ts index 3901099e3..caf72f9c3 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -251,10 +251,12 @@ export class Server< this._clientCapabilities = request.params.capabilities; this._clientVersion = request.params.clientInfo; - return { - protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) + const protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) ? requestedVersion - : LATEST_PROTOCOL_VERSION, + : LATEST_PROTOCOL_VERSION; + + return { + protocolVersion, capabilities: this.getCapabilities(), serverInfo: this._serverInfo, ...(this._instructions && { instructions: this._instructions }), diff --git a/src/server/sse.ts b/src/server/sse.ts index 03f6fefc9..e9a4d53ab 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -17,7 +17,6 @@ const MAXIMUM_MESSAGE_SIZE = "4mb"; export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; private _sessionId: string; - onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index b961f6c41..d66083fe8 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -185,6 +185,8 @@ async function sendPostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMe if (sessionId) { headers["mcp-session-id"] = sessionId; + // After initialization, include the protocol version header + headers["mcp-protocol-version"] = "2025-03-26"; } return fetch(baseUrl, { @@ -277,7 +279,7 @@ describe("StreamableHTTPServerTransport", () => { expectErrorResponse(errorData, -32600, /Only one initialization request is allowed/); }); - it("should pandle post requests via sse response correctly", async () => { + it("should handle post requests via sse response correctly", async () => { sessionId = await initializeServer(); const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); @@ -376,6 +378,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -417,6 +420,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -448,6 +452,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -459,6 +464,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -477,6 +483,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "application/json", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -670,6 +677,7 @@ describe("StreamableHTTPServerTransport", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -705,7 +713,10 @@ describe("StreamableHTTPServerTransport", () => { // Now DELETE the session const deleteResponse = await fetch(tempUrl, { method: "DELETE", - headers: { "mcp-session-id": tempSessionId || "" }, + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, }); expect(deleteResponse.status).toBe(200); @@ -721,13 +732,124 @@ describe("StreamableHTTPServerTransport", () => { // Try to delete with invalid session ID const response = await fetch(baseUrl, { method: "DELETE", - headers: { "mcp-session-id": "invalid-session-id" }, + headers: { + "mcp-session-id": "invalid-session-id", + "mcp-protocol-version": "2025-03-26", + }, }); expect(response.status).toBe(404); const errorData = await response.json(); expectErrorResponse(errorData, -32001, /Session not found/); }); + + describe("protocol version header validation", () => { + it("should accept requests with matching protocol version", async () => { + sessionId = await initializeServer(); + + // Send request with matching protocol version + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); + + expect(response.status).toBe(200); + }); + + it("should accept requests without protocol version header", async () => { + sessionId = await initializeServer(); + + // Send request without protocol version header + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + "mcp-session-id": sessionId, + // No mcp-protocol-version header + }, + body: JSON.stringify(TEST_MESSAGES.toolsList), + }); + + expect(response.status).toBe(200); + }); + + it("should reject requests with unsupported protocol version", async () => { + sessionId = await initializeServer(); + + // Send request with unsupported protocol version + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + "mcp-session-id": sessionId, + "mcp-protocol-version": "1999-01-01", // Unsupported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList), + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it("should accept when protocol version differs from negotiated version", async () => { + sessionId = await initializeServer(); + + // Spy on console.warn to verify warning is logged + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); + + // Send request with different but supported protocol version + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + "mcp-session-id": sessionId, + "mcp-protocol-version": "2024-11-05", // Different but supported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList), + }); + + // Request should still succeed + expect(response.status).toBe(200); + + warnSpy.mockRestore(); + }); + + it("should handle protocol version validation for GET requests", async () => { + sessionId = await initializeServer(); + + // GET request with unsupported protocol version + const response = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + "mcp-session-id": sessionId, + "mcp-protocol-version": "invalid-version", + }, + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it("should handle protocol version validation for DELETE requests", async () => { + sessionId = await initializeServer(); + + // DELETE request with unsupported protocol version + const response = await fetch(baseUrl, { + method: "DELETE", + headers: { + "mcp-session-id": sessionId, + "mcp-protocol-version": "invalid-version", + }, + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + }); }); describe("StreamableHTTPServerTransport with AuthInfo", () => { @@ -1120,6 +1242,7 @@ describe("StreamableHTTPServerTransport with resumability", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", }, }); @@ -1196,6 +1319,7 @@ describe("StreamableHTTPServerTransport with resumability", () => { headers: { Accept: "text/event-stream", "mcp-session-id": sessionId, + "mcp-protocol-version": "2025-03-26", "last-event-id": firstEventId }, }); @@ -1282,14 +1406,20 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { // Open first SSE stream const stream1 = await fetch(baseUrl, { method: "GET", - headers: { Accept: "text/event-stream" }, + headers: { + Accept: "text/event-stream", + "mcp-protocol-version": "2025-03-26" + }, }); expect(stream1.status).toBe(200); // Open second SSE stream - should still be rejected, stateless mode still only allows one const stream2 = await fetch(baseUrl, { method: "GET", - headers: { Accept: "text/event-stream" }, + headers: { + Accept: "text/event-stream", + "mcp-protocol-version": "2025-03-26" + }, }); expect(stream2.status).toBe(409); // Conflict - only one stream allowed }); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index dc99c3065..34b2ab68a 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,6 +1,6 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; +import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; @@ -110,7 +110,7 @@ export class StreamableHTTPServerTransport implements Transport { private _eventStore?: EventStore; private _onsessioninitialized?: (sessionId: string) => void; - sessionId?: string | undefined; + sessionId?: string; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; @@ -172,6 +172,9 @@ export class StreamableHTTPServerTransport implements Transport { if (!this.validateSession(req, res)) { return; } + if (!this.validateProtocolVersion(req, res)) { + return; + } // Handle resumability: check for Last-Event-ID header if (this._eventStore) { const lastEventId = req.headers['last-event-id'] as string | undefined; @@ -378,11 +381,17 @@ export class StreamableHTTPServerTransport implements Transport { } } - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - if (!isInitializationRequest && !this.validateSession(req, res)) { - return; + if (!isInitializationRequest) { + // If an Mcp-Session-Id is returned by the server during initialization, + // clients using the Streamable HTTP transport MUST include it + // in the Mcp-Session-Id header on all of their subsequent HTTP requests. + if (!this.validateSession(req, res)) { + return; + } + // Mcp-Protocol-Version header is required for all requests after initialization. + if (!this.validateProtocolVersion(req, res)) { + return; + } } @@ -457,6 +466,9 @@ export class StreamableHTTPServerTransport implements Transport { if (!this.validateSession(req, res)) { return; } + if (!this.validateProtocolVersion(req, res)) { + return; + } await this.close(); res.writeHead(200).end(); } @@ -524,6 +536,25 @@ export class StreamableHTTPServerTransport implements Transport { return true; } + private validateProtocolVersion(req: IncomingMessage, res: ServerResponse): boolean { + let protocolVersion = req.headers["mcp-protocol-version"] ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION; + if (Array.isArray(protocolVersion)) { + protocolVersion = protocolVersion[protocolVersion.length - 1]; + } + + if (!SUPPORTED_PROTOCOL_VERSIONS.includes(protocolVersion)) { + res.writeHead(400).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: `Bad Request: Unsupported protocol version (supported versions: ${SUPPORTED_PROTOCOL_VERSIONS.join(", ")})` + }, + id: null + })); + return false; + } + return true; + } async close(): Promise { // Close all SSE connections diff --git a/src/shared/transport.ts b/src/shared/transport.ts index fe0a60e6d..b75e072e8 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -75,4 +75,9 @@ export interface Transport { * The session ID generated for this connection. */ sessionId?: string; + + /** + * Sets the protocol version used for the connection (called when the initialize response is received). + */ + setProtocolVersion?: (version: string) => void; } diff --git a/src/types.ts b/src/types.ts index ae25848ea..4bf87ef65 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,6 +1,7 @@ import { z, ZodTypeAny } from "zod"; export const LATEST_PROTOCOL_VERSION = "2025-03-26"; +export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = "2025-03-26"; export const SUPPORTED_PROTOCOL_VERSIONS = [ LATEST_PROTOCOL_VERSION, "2024-11-05",