Skip to content

Commit 5f1e1b2

Browse files
committed
add tests
1 parent 5249266 commit 5f1e1b2

File tree

3 files changed

+79
-34
lines changed

3 files changed

+79
-34
lines changed

src/common/config.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ export interface UserConfig extends CliOptions {
116116
transport: "stdio" | "http";
117117
httpPort: number;
118118
httpHost: string;
119-
httpHeaders?: Record<string, string>;
119+
httpHeaders: Record<string, string>;
120120
loggers: Array<"stderr" | "disk" | "mcp">;
121121
idleTimeoutMs: number;
122122
notificationTimeoutMs: number;
@@ -138,6 +138,7 @@ export const defaultUserConfig: UserConfig = {
138138
loggers: ["disk", "mcp"],
139139
idleTimeoutMs: 600000, // 10 minutes
140140
notificationTimeoutMs: 540000, // 9 minutes
141+
httpHeaders: {},
141142
};
142143

143144
export const config = setupUserConfig({

src/transports/streamableHttp.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,11 @@ export class StreamableHttpRunner extends TransportRunnerBase {
4545
app.enable("trust proxy"); // needed for reverse proxy support
4646
app.use(express.json());
4747
app.use((req, res, next) => {
48-
if (this.userConfig.httpHeaders) {
49-
for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) {
50-
const header = req.headers[key];
51-
if (!header || header !== value) {
52-
res.sendStatus(403).json({ error: `Invalid ${key} header` });
53-
return;
54-
}
48+
for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) {
49+
const header = req.headers[key];
50+
if (!header || header !== value) {
51+
res.sendStatus(403).json({ error: `Invalid ${key} header` });
52+
return;
5553
}
5654
}
5755

tests/integration/transports/streamableHttp.test.ts

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,12 @@ describe("StreamableHttpRunner", () => {
99
let oldTelemetry: "enabled" | "disabled";
1010
let oldLoggers: ("stderr" | "disk" | "mcp")[];
1111

12-
beforeAll(async () => {
12+
beforeAll(() => {
1313
oldTelemetry = config.telemetry;
1414
oldLoggers = config.loggers;
1515
config.telemetry = "disabled";
1616
config.loggers = ["stderr"];
1717
config.httpPort = 0; // Use a random port for testing
18-
runner = new StreamableHttpRunner(config);
19-
await runner.start();
2018
});
2119

2220
afterAll(async () => {
@@ -25,33 +23,81 @@ describe("StreamableHttpRunner", () => {
2523
config.loggers = oldLoggers;
2624
});
2725

28-
describe("client connects successfully", () => {
29-
let client: Client;
30-
let transport: StreamableHTTPClientTransport;
31-
beforeAll(async () => {
32-
transport = new StreamableHTTPClientTransport(new URL(`${runner.address}/mcp`));
26+
const headerTestCases: { headers: Record<string, string>; description: string }[] = [
27+
{ headers: {}, description: "without headers" },
28+
{ headers: { "x-custom-header": "test-value" }, description: "with headers" },
29+
];
3330

34-
client = new Client({
35-
name: "test",
36-
version: "0.0.0",
31+
for (const { headers, description } of headerTestCases) {
32+
describe(description, () => {
33+
beforeAll(async () => {
34+
config.httpHeaders = headers;
35+
runner = new StreamableHttpRunner(config);
36+
await runner.start();
3737
});
38-
await client.connect(transport);
39-
});
4038

41-
afterAll(async () => {
42-
await client.close();
43-
await transport.close();
44-
});
39+
const clientHeaderTestCases = [
40+
{
41+
headers: {},
42+
description: "without client headers",
43+
expectSuccess: Object.keys(headers).length === 0,
44+
},
45+
{ headers, description: "with matching client headers", expectSuccess: true },
46+
{ headers: { ...headers, foo: "bar" }, description: "with extra client headers", expectSuccess: true },
47+
{
48+
headers: { foo: "bar" },
49+
description: "with non-matching client headers",
50+
expectSuccess: Object.keys(headers).length === 0,
51+
},
52+
];
53+
54+
for (const {
55+
headers: clientHeaders,
56+
description: clientDescription,
57+
expectSuccess,
58+
} of clientHeaderTestCases) {
59+
describe(clientDescription, () => {
60+
let client: Client;
61+
let transport: StreamableHTTPClientTransport;
62+
beforeAll(() => {
63+
client = new Client({
64+
name: "test",
65+
version: "0.0.0",
66+
});
67+
transport = new StreamableHTTPClientTransport(new URL(`${runner.address}/mcp`), {
68+
requestInit: {
69+
headers: clientHeaders,
70+
},
71+
});
72+
});
4573

46-
it("handles requests and sends responses", async () => {
47-
const response = await client.listTools();
48-
expect(response).toBeDefined();
49-
expect(response.tools).toBeDefined();
50-
expect(response.tools.length).toBeGreaterThan(0);
74+
afterAll(async () => {
75+
await client.close();
76+
await transport.close();
77+
});
5178

52-
const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name));
53-
expect(sortedTools[0]?.name).toBe("aggregate");
54-
expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
79+
it(`should ${expectSuccess ? "succeed" : "fail"}`, async () => {
80+
try {
81+
await client.connect(transport);
82+
const response = await client.listTools();
83+
expect(response).toBeDefined();
84+
expect(response.tools).toBeDefined();
85+
expect(response.tools.length).toBeGreaterThan(0);
86+
87+
const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name));
88+
expect(sortedTools[0]?.name).toBe("aggregate");
89+
expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
90+
} catch (err) {
91+
if (expectSuccess) {
92+
throw err;
93+
} else {
94+
expect(err).toBeDefined();
95+
expect(err?.toString()).toContain("HTTP 403");
96+
}
97+
}
98+
});
99+
});
100+
}
55101
});
56-
});
102+
}
57103
});

0 commit comments

Comments
 (0)