Skip to content

feat: Add custom context support for MCP transports #819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions src/inMemory.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Transport } from "./shared/transport.js";
import { JSONRPCMessage, RequestId } from "./types.js";
import { JSONRPCMessage, RequestId, MessageExtraInfo } from "./types.js";
import { AuthInfo } from "./server/auth/types.js";

interface QueuedMessage {
message: JSONRPCMessage;
extra?: { authInfo?: AuthInfo };
extra?: MessageExtraInfo;
}

/**
Expand All @@ -13,10 +13,11 @@ interface QueuedMessage {
export class InMemoryTransport implements Transport {
private _otherTransport?: InMemoryTransport;
private _messageQueue: QueuedMessage[] = [];
private _customContext?: Record<string, unknown>;

onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
sessionId?: string;

/**
Expand All @@ -34,7 +35,12 @@ export class InMemoryTransport implements Transport {
// Process any messages that were queued before start was called
while (this._messageQueue.length > 0) {
const queuedMessage = this._messageQueue.shift()!;
this.onmessage?.(queuedMessage.message, queuedMessage.extra);
// Merge custom context with queued extra info
const enhancedExtra: MessageExtraInfo = {
...queuedMessage.extra,
customContext: this._customContext
};
this.onmessage?.(queuedMessage.message, enhancedExtra);
}
}

Expand All @@ -46,18 +52,45 @@ export class InMemoryTransport implements Transport {
}

/**
* Sends a message with optional auth info.
* This is useful for testing authentication scenarios.
* Sends a message with optional extra info.
* This is useful for testing authentication scenarios and custom context.
*
* @deprecated The authInfo parameter is deprecated. Use MessageExtraInfo instead.
*/
async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo }): Promise<void> {
async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId, authInfo?: AuthInfo } | MessageExtraInfo): Promise<void> {
if (!this._otherTransport) {
throw new Error("Not connected");
}

// Handle both old and new API formats
let extra: MessageExtraInfo | undefined;
if (options && 'authInfo' in options && !('requestInfo' in options)) {
// Old API format - convert to new format
extra = { authInfo: options.authInfo };
} else if (options && ('requestInfo' in options || 'customContext' in options || 'authInfo' in options)) {
// New API format
extra = options as MessageExtraInfo;
} else if (options && 'authInfo' in options) {
// Old API with authInfo
extra = { authInfo: options.authInfo };
}

if (this._otherTransport.onmessage) {
this._otherTransport.onmessage(message, { authInfo: options?.authInfo });
// Merge the other transport's custom context with the extra info
const enhancedExtra: MessageExtraInfo = {
...extra,
customContext: this._otherTransport._customContext
};
this._otherTransport.onmessage(message, enhancedExtra);
} else {
this._otherTransport._messageQueue.push({ message, extra: { authInfo: options?.authInfo } });
this._otherTransport._messageQueue.push({ message, extra });
}
}

/**
* Sets custom context data that will be passed to all message handlers.
*/
setCustomContext(context: Record<string, unknown>): void {
this._customContext = context;
}
}
58 changes: 58 additions & 0 deletions src/server/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,64 @@ describe("tool()", () => {
expect(result.content && result.content[0].text).toContain("Received request ID:");
});

/***
* Test: Pass Custom Context to Tool Callback
*/
test("should pass customContext to tool callback via RequestHandlerExtra", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});

const client = new Client({
name: "test client",
version: "1.0",
});

let receivedCustomContext: Record<string, unknown> | undefined;
mcpServer.tool("custom-context-test", async (extra) => {
receivedCustomContext = extra.customContext;
return {
content: [
{
type: "text",
text: `Custom context: ${JSON.stringify(extra.customContext)}`,
},
],
};
});

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();

// Use the new setCustomContext method to inject custom context
serverTransport.setCustomContext({
tenantId: "test-tenant-123",
featureFlags: { newFeature: true },
customData: "test-value"
});

await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);

const result = await client.request(
{
method: "tools/call",
params: {
name: "custom-context-test",
},
},
CallToolResultSchema,
);

expect(receivedCustomContext).toBeDefined();
expect(receivedCustomContext?.tenantId).toBe("test-tenant-123");
expect(receivedCustomContext?.featureFlags).toEqual({ newFeature: true });
expect(receivedCustomContext?.customData).toBe("test-value");
expect(result.content && result.content[0].text).toContain("test-tenant-123");
});

/***
* Test: Send Notification within Tool Call
*/
Expand Down
15 changes: 14 additions & 1 deletion src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export class SSEServerTransport implements Transport {
private _sseResponse?: ServerResponse;
private _sessionId: string;
private _options: SSEServerTransportOptions;
private _customContext?: Record<string, unknown>;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
Expand Down Expand Up @@ -191,7 +192,12 @@ export class SSEServerTransport implements Transport {
throw error;
}

this.onmessage?.(parsedMessage, extra);
// Merge custom context with the extra info
const enhancedExtra: MessageExtraInfo = {
...extra,
customContext: this._customContext
};
this.onmessage?.(parsedMessage, enhancedExtra);
}

async close(): Promise<void> {
Expand All @@ -218,4 +224,11 @@ export class SSEServerTransport implements Transport {
get sessionId(): string {
return this._sessionId;
}

/**
* Sets custom context data that will be passed to all message handlers.
*/
setCustomContext(context: Record<string, unknown>): void {
this._customContext = context;
}
}
18 changes: 15 additions & 3 deletions src/server/stdio.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import process from "node:process";
import { Readable, Writable } from "node:stream";
import { ReadBuffer, serializeMessage } from "../shared/stdio.js";
import { JSONRPCMessage } from "../types.js";
import { JSONRPCMessage, MessageExtraInfo } from "../types.js";
import { Transport } from "../shared/transport.js";

/**
Expand All @@ -12,6 +12,7 @@ import { Transport } from "../shared/transport.js";
export class StdioServerTransport implements Transport {
private _readBuffer: ReadBuffer = new ReadBuffer();
private _started = false;
private _customContext?: Record<string, unknown>;

constructor(
private _stdin: Readable = process.stdin,
Expand All @@ -20,7 +21,7 @@ export class StdioServerTransport implements Transport {

onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;

// Arrow functions to bind `this` properly, while maintaining function identity.
_ondata = (chunk: Buffer) => {
Expand Down Expand Up @@ -54,7 +55,11 @@ export class StdioServerTransport implements Transport {
break;
}

this.onmessage?.(message);
// Pass custom context to message handlers
const extra: MessageExtraInfo = {
customContext: this._customContext
};
this.onmessage?.(message, extra);
} catch (error) {
this.onerror?.(error as Error);
}
Expand Down Expand Up @@ -89,4 +94,11 @@ export class StdioServerTransport implements Transport {
}
});
}

/**
* Sets custom context data that will be passed to all message handlers.
*/
setCustomContext(context: Record<string, unknown>): void {
this._customContext = context;
}
}
22 changes: 20 additions & 2 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ export class StreamableHTTPServerTransport implements Transport {
private _allowedHosts?: string[];
private _allowedOrigins?: string[];
private _enableDnsRebindingProtection: boolean;
private _customContext?: Record<string, unknown>;

sessionId?: string;
onclose?: () => void;
Expand Down Expand Up @@ -487,7 +488,12 @@ export class StreamableHTTPServerTransport implements Transport {

// handle each message
for (const message of messages) {
this.onmessage?.(message, { authInfo, requestInfo });
const enhancedExtra: MessageExtraInfo = {
authInfo,
requestInfo,
customContext: this._customContext
};
this.onmessage?.(message, enhancedExtra);
}
} else if (hasRequests) {
// The default behavior is to use SSE streaming
Expand Down Expand Up @@ -522,7 +528,12 @@ export class StreamableHTTPServerTransport implements Transport {

// handle each message
for (const message of messages) {
this.onmessage?.(message, { authInfo, requestInfo });
const enhancedExtra: MessageExtraInfo = {
authInfo,
requestInfo,
customContext: this._customContext
};
this.onmessage?.(message, enhancedExtra);
}
// The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses
// This will be handled by the send() method when responses are ready
Expand Down Expand Up @@ -748,5 +759,12 @@ export class StreamableHTTPServerTransport implements Transport {
}
}
}

/**
* Sets custom context data that will be passed to all message handlers.
*/
setCustomContext(context: Record<string, unknown>): void {
this._customContext = context;
}
}

10 changes: 9 additions & 1 deletion src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ export type RequestHandlerExtra<SendRequestT extends Request,
*/
requestInfo?: RequestInfo;

/**
* Custom context data passed from the transport layer.
* This allows transport implementations to attach arbitrary data that will be
* available to request handlers.
*/
customContext?: Record<string, unknown>;

/**
* Sends a notification that relates to the current request being handled.
*
Expand Down Expand Up @@ -405,7 +412,8 @@ export abstract class Protocol<
this.request(r, resultSchema, { ...options, relatedRequestId: request.id }),
authInfo: extra?.authInfo,
requestId: request.id,
requestInfo: extra?.requestInfo
requestInfo: extra?.requestInfo,
customContext: extra?.customContext
};

// Starting with Promise.resolve() puts any synchronous errors into the monad as well.
Expand Down
6 changes: 6 additions & 0 deletions src/shared/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,10 @@ export interface Transport {
* Sets the protocol version used for the connection (called when the initialize response is received).
*/
setProtocolVersion?: (version: string) => void;

/**
* Sets custom context data that will be passed to all message handlers.
* This context will be included in the MessageExtraInfo passed to handlers.
*/
setCustomContext?: (context: Record<string, unknown>) => void;
}
7 changes: 7 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1512,6 +1512,13 @@ export interface MessageExtraInfo {
* The authentication information.
*/
authInfo?: AuthInfo;

/**
* Custom context data that can be passed through the message handling pipeline.
* This allows transport implementations to attach arbitrary data that will be
* available to request handlers.
*/
customContext?: Record<string, unknown>;
}

/* JSON-RPC types */
Expand Down