Skip to content
Merged
4 changes: 4 additions & 0 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ export class Client<

this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
// HTTP transports must set the protocol version in each header after initialization.
if (transport.setProtocolVersion) {
transport.setProtocolVersion(result.protocolVersion);
}

this._instructions = result.instructions;

Expand Down
16 changes: 13 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export class SSEClientTransport implements Transport {
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
private _protocolVersion?: string;

onclose?: () => void;
onerror?: (error: Error) => void;
Expand Down Expand Up @@ -99,13 +100,18 @@ export class SSEClientTransport implements Transport {
}

private async _commonHeaders(): Promise<HeadersInit> {
const headers: HeadersInit = { ...this._requestInit?.headers };
const headers = {
...this._requestInit?.headers,
} as HeadersInit & Record<string, string>;
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
(headers as Record<string, string>)["Authorization"] = `Bearer ${tokens.access_token}`;
headers["Authorization"] = `Bearer ${tokens.access_token}`;
}
}
if (this._protocolVersion) {
headers["mcp-protocol-version"] = this._protocolVersion;
}

return headers;
}
Expand Down Expand Up @@ -214,7 +220,7 @@ export class SSEClientTransport implements Transport {

try {
const commonHeaders = await this._commonHeaders();
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
const headers = new Headers(commonHeaders);
headers.set("content-type", "application/json");
const init = {
...this._requestInit,
Expand Down Expand Up @@ -249,4 +255,8 @@ export class SSEClientTransport implements Transport {
throw error;
}
}

setProtocolVersion(version: string): void {
this._protocolVersion = version;
}
}
13 changes: 12 additions & 1 deletion src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ export class StreamableHTTPClientTransport implements Transport {
private _authProvider?: OAuthClientProvider;
private _sessionId?: string;
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
private _protocolVersion?: string;

onclose?: () => void;
onerror?: (error: Error) => void;
Expand Down Expand Up @@ -162,7 +163,7 @@ export class StreamableHTTPClientTransport implements Transport {
}

private async _commonHeaders(): Promise<Headers> {
const headers: HeadersInit = {};
const headers: HeadersInit & Record<string, string> = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
Expand All @@ -173,6 +174,9 @@ export class StreamableHTTPClientTransport implements Transport {
if (this._sessionId) {
headers["mcp-session-id"] = this._sessionId;
}
if (this._protocolVersion) {
headers["mcp-protocol-version"] = this._protocolVersion;
}

return new Headers(
{ ...headers, ...this._requestInit?.headers }
Expand Down Expand Up @@ -516,4 +520,11 @@ export class StreamableHTTPClientTransport implements Transport {
throw error;
}
}

setProtocolVersion(version: string): void {
this._protocolVersion = version;
}
get protocolVersion(): string | undefined {
return this._protocolVersion;
}
}
23 changes: 22 additions & 1 deletion src/integration-tests/stateManagementStreamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Client } from '../client/index.js';
import { StreamableHTTPClientTransport } from '../client/streamableHttp.js';
import { McpServer } from '../server/mcp.js';
import { StreamableHTTPServerTransport } from '../server/streamableHttp.js';
import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema } from '../types.js';
import { CallToolResultSchema, ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from '../types.js';
import { z } from 'zod';

describe('Streamable HTTP Transport Session Management', () => {
Expand Down Expand Up @@ -211,6 +211,27 @@ describe('Streamable HTTP Transport Session Management', () => {
// Clean up
await transport.close();
});

it('should set protocol version after connecting', async () => {
// Create and connect a client
const client = new Client({
name: 'test-client',
version: '1.0.0'
});

const transport = new StreamableHTTPClientTransport(baseUrl);

// Verify protocol version is not set before connecting
expect(transport.protocolVersion).toBeUndefined();

await client.connect(transport);

// Verify protocol version is set after connecting
expect(transport.protocolVersion).toBe(DEFAULT_NEGOTIATED_PROTOCOL_VERSION);

// Clean up
await transport.close();
});
});

describe('Stateful Mode', () => {
Expand Down
8 changes: 5 additions & 3 deletions src/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,12 @@ export class Server<
this._clientCapabilities = request.params.capabilities;
this._clientVersion = request.params.clientInfo;

return {
protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion)
const protocolVersion = SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion)
? requestedVersion
: LATEST_PROTOCOL_VERSION,
: LATEST_PROTOCOL_VERSION;

return {
protocolVersion,
capabilities: this.getCapabilities(),
serverInfo: this._serverInfo,
...(this._instructions && { instructions: this._instructions }),
Expand Down
1 change: 0 additions & 1 deletion src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ const MAXIMUM_MESSAGE_SIZE = "4mb";
export class SSEServerTransport implements Transport {
private _sseResponse?: ServerResponse;
private _sessionId: string;

onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
Expand Down
Loading