Skip to content

Commit 2be0d17

Browse files
committed
feat: add session management for streamableHttp [MCP-52]
1 parent b7c616a commit 2be0d17

File tree

6 files changed

+182
-148
lines changed

6 files changed

+182
-148
lines changed

.vscode/launch.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"name": "Launch Program",
2020
"skipFiles": ["<node_internals>/**"],
2121
"program": "${workspaceFolder}/dist/index.js",
22+
"args": ["--transport", "http", "--loggers", "stderr", "mcp"],
2223
"preLaunchTask": "tsc: build - tsconfig.build.json",
2324
"outFiles": ["${workspaceFolder}/dist/**/*.js"]
2425
}

src/common/logger.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,9 @@ export const LogId = {
4040
toolUpdateFailure: mongoLogId(1_005_001),
4141

4242
streamableHttpTransportStarted: mongoLogId(1_006_001),
43-
streamableHttpTransportStartFailure: mongoLogId(1_006_002),
44-
streamableHttpTransportSessionInitialized: mongoLogId(1_006_003),
45-
streamableHttpTransportRequestFailure: mongoLogId(1_006_004),
46-
streamableHttpTransportCloseRequested: mongoLogId(1_006_005),
47-
streamableHttpTransportCloseSuccess: mongoLogId(1_006_006),
48-
streamableHttpTransportCloseFailure: mongoLogId(1_006_007),
43+
streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002),
44+
streamableHttpTransportRequestFailure: mongoLogId(1_006_003),
45+
streamableHttpTransportCloseFailure: mongoLogId(1_006_004),
4946
} as const;
5047

5148
export abstract class LoggerBase {

src/common/sessionStore.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
2+
import logger, { LogId } from "./logger.js";
3+
4+
export class SessionStore {
5+
private sessions: Map<string, StreamableHTTPServerTransport> = new Map();
6+
7+
getSession(sessionId: string): StreamableHTTPServerTransport | undefined {
8+
return this.sessions.get(sessionId);
9+
}
10+
11+
setSession(sessionId: string, transport: StreamableHTTPServerTransport): void {
12+
if (this.sessions.has(sessionId)) {
13+
throw new Error(`Session ${sessionId} already exists`);
14+
}
15+
this.sessions.set(sessionId, transport);
16+
}
17+
18+
async closeSession(sessionId: string, closeTransport: boolean = true): Promise<void> {
19+
if (!this.sessions.has(sessionId)) {
20+
throw new Error(`Session ${sessionId} not found`);
21+
}
22+
if (closeTransport) {
23+
const transport = this.sessions.get(sessionId);
24+
if (!transport) {
25+
throw new Error(`Session ${sessionId} not found`);
26+
}
27+
try {
28+
await transport.close();
29+
} catch (error) {
30+
logger.error(
31+
LogId.streamableHttpTransportSessionCloseFailure,
32+
"streamableHttpTransport",
33+
`Error closing transport ${sessionId}: ${error instanceof Error ? error.message : String(error)}`
34+
);
35+
}
36+
}
37+
this.sessions.delete(sessionId);
38+
}
39+
40+
async closeAllSessions(): Promise<void> {
41+
await Promise.all(Array.from(this.sessions.values()).map((transport) => transport.close()));
42+
this.sessions.clear();
43+
}
44+
}

src/transports/streamableHttp.ts

Lines changed: 89 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,132 @@
11
import express from "express";
22
import http from "http";
33
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
4+
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
45
import { TransportRunnerBase } from "./base.js";
56
import { config } from "../common/config.js";
67
import logger, { LogId } from "../common/logger.js";
8+
import { randomUUID } from "crypto";
9+
import { SessionStore } from "../common/sessionStore.js";
710

811
const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000;
9-
const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601;
12+
const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001;
13+
const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32002;
14+
const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32003;
1015

1116
function promiseHandler(
1217
fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise<void>
1318
) {
1419
return (req: express.Request, res: express.Response, next: express.NextFunction) => {
15-
fn(req, res, next).catch(next);
20+
fn(req, res, next).catch((error) => {
21+
logger.error(
22+
LogId.streamableHttpTransportRequestFailure,
23+
"streamableHttpTransport",
24+
`Error handling request: ${error instanceof Error ? error.message : String(error)}`
25+
);
26+
res.status(400).json({
27+
jsonrpc: "2.0",
28+
error: {
29+
code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED,
30+
message: `failed to handle request`,
31+
data: error instanceof Error ? error.message : String(error),
32+
},
33+
});
34+
});
1635
};
1736
}
1837

1938
export class StreamableHttpRunner extends TransportRunnerBase {
2039
private httpServer: http.Server | undefined;
40+
private sessionStore: SessionStore = new SessionStore();
2141

2242
async start() {
2343
const app = express();
2444
app.enable("trust proxy"); // needed for reverse proxy support
25-
app.use(express.urlencoded({ extended: true }));
2645
app.use(express.json());
2746

47+
const handleRequest = async (req: express.Request, res: express.Response) => {
48+
const sessionId = req.headers["mcp-session-id"] as string;
49+
if (!sessionId) {
50+
res.status(400).json({
51+
jsonrpc: "2.0",
52+
error: {
53+
code: JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED,
54+
message: `session id is required`,
55+
},
56+
});
57+
return;
58+
}
59+
const transport = this.sessionStore.getSession(sessionId);
60+
if (!transport) {
61+
res.status(404).json({
62+
jsonrpc: "2.0",
63+
error: {
64+
code: JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND,
65+
message: `session not found`,
66+
},
67+
});
68+
return;
69+
}
70+
await transport.handleRequest(req, res, req.body);
71+
};
72+
2873
app.post(
2974
"/mcp",
3075
promiseHandler(async (req: express.Request, res: express.Response) => {
31-
const transport = new StreamableHTTPServerTransport({
32-
sessionIdGenerator: undefined,
33-
});
76+
const sessionId = req.headers["mcp-session-id"] as string;
77+
if (sessionId) {
78+
await handleRequest(req, res);
79+
return;
80+
}
3481

35-
const server = this.setupServer();
82+
if (!isInitializeRequest(req.body)) {
83+
res.status(400).json({
84+
jsonrpc: "2.0",
85+
error: {
86+
code: JSON_RPC_ERROR_CODE_INVALID_REQUEST,
87+
message: `invalid request`,
88+
},
89+
});
90+
return;
91+
}
3692

37-
await server.connect(transport);
93+
const server = this.setupServer();
94+
const transport = new StreamableHTTPServerTransport({
95+
sessionIdGenerator: () => randomUUID().toString(),
96+
onsessioninitialized: (sessionId) => {
97+
this.sessionStore.setSession(sessionId, transport);
98+
},
99+
onsessionclosed: async (sessionId) => {
100+
try {
101+
await this.sessionStore.closeSession(sessionId, false);
102+
} catch (error) {
103+
logger.error(
104+
LogId.streamableHttpTransportSessionCloseFailure,
105+
"streamableHttpTransport",
106+
`Error closing session: ${error instanceof Error ? error.message : String(error)}`
107+
);
108+
}
109+
},
110+
});
38111

39-
res.on("close", () => {
40-
Promise.all([transport.close(), server.close()]).catch((error: unknown) => {
112+
transport.onclose = () => {
113+
server.close().catch((error) => {
41114
logger.error(
42115
LogId.streamableHttpTransportCloseFailure,
43116
"streamableHttpTransport",
44117
`Error closing server: ${error instanceof Error ? error.message : String(error)}`
45118
);
46119
});
47-
});
120+
};
48121

49-
try {
50-
await transport.handleRequest(req, res, req.body);
51-
} catch (error) {
52-
logger.error(
53-
LogId.streamableHttpTransportRequestFailure,
54-
"streamableHttpTransport",
55-
`Error handling request: ${error instanceof Error ? error.message : String(error)}`
56-
);
57-
res.status(400).json({
58-
jsonrpc: "2.0",
59-
error: {
60-
code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED,
61-
message: `failed to handle request`,
62-
data: error instanceof Error ? error.message : String(error),
63-
},
64-
});
65-
}
122+
await server.connect(transport);
123+
124+
await transport.handleRequest(req, res, req.body);
66125
})
67126
);
68127

69-
app.get("/mcp", (req: express.Request, res: express.Response) => {
70-
res.status(405).json({
71-
jsonrpc: "2.0",
72-
error: {
73-
code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED,
74-
message: `method not allowed`,
75-
},
76-
});
77-
});
78-
79-
app.delete("/mcp", (req: express.Request, res: express.Response) => {
80-
res.status(405).json({
81-
jsonrpc: "2.0",
82-
error: {
83-
code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED,
84-
message: `method not allowed`,
85-
},
86-
});
87-
});
128+
app.get("/mcp", promiseHandler(handleRequest));
129+
app.delete("/mcp", promiseHandler(handleRequest));
88130

89131
this.httpServer = await new Promise<http.Server>((resolve, reject) => {
90132
const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => {
Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,40 @@
1-
import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
21
import { describe, expect, it, beforeAll, afterAll } from "vitest";
32
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
3+
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
44

55
describe("StdioRunner", () => {
66
describe("client connects successfully", () => {
7-
let client: StdioClientTransport;
7+
let client: Client;
8+
let transport: StdioClientTransport;
89
beforeAll(async () => {
9-
client = new StdioClientTransport({
10+
transport = new StdioClientTransport({
1011
command: "node",
1112
args: ["dist/index.js"],
1213
env: {
1314
MDB_MCP_TRANSPORT: "stdio",
1415
},
1516
});
16-
await client.start();
17+
client = new Client({
18+
name: "test",
19+
version: "0.0.0",
20+
});
21+
await client.connect(transport);
1722
});
1823

1924
afterAll(async () => {
2025
await client.close();
26+
await transport.close();
2127
});
2228

2329
it("handles requests and sends responses", async () => {
24-
let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined;
25-
const messagePromise = new Promise<JSONRPCMessage>((resolve) => {
26-
fixedResolve = resolve;
27-
});
28-
29-
client.onmessage = (message: JSONRPCMessage) => {
30-
fixedResolve?.(message);
31-
};
32-
33-
await client.send({
34-
jsonrpc: "2.0",
35-
id: 1,
36-
method: "tools/list",
37-
params: {
38-
_meta: {
39-
progressToken: 1,
40-
},
41-
},
42-
});
43-
44-
const message = (await messagePromise) as {
45-
jsonrpc: string;
46-
id: number;
47-
result: {
48-
tools: {
49-
name: string;
50-
description: string;
51-
}[];
52-
};
53-
error?: {
54-
code: number;
55-
message: string;
56-
};
57-
};
30+
const response = await client.listTools();
31+
expect(response).toBeDefined();
32+
expect(response.tools).toBeDefined();
33+
expect(response.tools).toHaveLength(20);
5834

59-
expect(message.jsonrpc).toBe("2.0");
60-
expect(message.id).toBe(1);
61-
expect(message.result).toBeDefined();
62-
expect(message.result?.tools).toBeDefined();
63-
expect(message.result?.tools.length).toBeGreaterThan(0);
64-
const tools = message.result?.tools;
65-
tools.sort((a, b) => a.name.localeCompare(b.name));
66-
expect(tools[0]?.name).toBe("aggregate");
67-
expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
35+
const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name));
36+
expect(sortedTools[0]?.name).toBe("aggregate");
37+
expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
6838
});
6939
});
7040
});

0 commit comments

Comments
 (0)