Skip to content
2 changes: 2 additions & 0 deletions src/common/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export interface UserConfig extends CliOptions {
transport: "stdio" | "http";
httpPort: number;
httpHost: string;
httpHeaders: Record<string, string>;
loggers: Array<"stderr" | "disk" | "mcp">;
idleTimeoutMs: number;
notificationTimeoutMs: number;
Expand All @@ -137,6 +138,7 @@ export const defaultUserConfig: UserConfig = {
loggers: ["disk", "mcp"],
idleTimeoutMs: 600000, // 10 minutes
notificationTimeoutMs: 540000, // 9 minutes
httpHeaders: {},
};

export const config = setupUserConfig({
Expand Down
3 changes: 2 additions & 1 deletion src/lib.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export { Server, type ServerOptions } from "./server.js";
export { Telemetry } from "./telemetry/telemetry.js";
export { Session, type SessionOptions } from "./common/session.js";
export type { UserConfig } from "./common/config.js";
export { type UserConfig, defaultUserConfig } from "./common/config.js";
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
23 changes: 23 additions & 0 deletions src/transports/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ export class StreamableHttpRunner extends TransportRunnerBase {
private httpServer: http.Server | undefined;
private sessionStore!: SessionStore;

public get address(): string {
const result = this.httpServer?.address();
if (typeof result === "string") {
return result;
}
if (typeof result === "object" && result) {
return `http://${result.address}:${result.port}`;
}

throw new Error("Server is not started yet");
}

constructor(userConfig: UserConfig) {
super(userConfig);
}
Expand All @@ -32,6 +44,17 @@ export class StreamableHttpRunner extends TransportRunnerBase {

app.enable("trust proxy"); // needed for reverse proxy support
app.use(express.json());
app.use((req, res, next) => {
for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) {
const header = req.headers[key];
if (!header || header !== value) {
res.sendStatus(403).json({ error: `Invalid ${key} header` });
return;
}
}

next();
});

const handleSessionRequest = async (req: express.Request, res: express.Response): Promise<void> => {
const sessionId = req.headers["mcp-session-id"];
Expand Down
99 changes: 73 additions & 26 deletions tests/integration/transports/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ describe("StreamableHttpRunner", () => {
let oldTelemetry: "enabled" | "disabled";
let oldLoggers: ("stderr" | "disk" | "mcp")[];

beforeAll(async () => {
beforeAll(() => {
oldTelemetry = config.telemetry;
oldLoggers = config.loggers;
config.telemetry = "disabled";
config.loggers = ["stderr"];
runner = new StreamableHttpRunner(config);
await runner.start();
config.httpPort = 0; // Use a random port for testing
});

afterAll(async () => {
Expand All @@ -24,33 +23,81 @@ describe("StreamableHttpRunner", () => {
config.loggers = oldLoggers;
});

describe("client connects successfully", () => {
let client: Client;
let transport: StreamableHTTPClientTransport;
beforeAll(async () => {
transport = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp"));
const headerTestCases: { headers: Record<string, string>; description: string }[] = [
{ headers: {}, description: "without headers" },
{ headers: { "x-custom-header": "test-value" }, description: "with headers" },
];

client = new Client({
name: "test",
version: "0.0.0",
for (const { headers, description } of headerTestCases) {
describe(description, () => {
beforeAll(async () => {
config.httpHeaders = headers;
runner = new StreamableHttpRunner(config);
await runner.start();
});
await client.connect(transport);
});

afterAll(async () => {
await client.close();
await transport.close();
});
const clientHeaderTestCases = [
{
headers: {},
description: "without client headers",
expectSuccess: Object.keys(headers).length === 0,
},
{ headers, description: "with matching client headers", expectSuccess: true },
{ headers: { ...headers, foo: "bar" }, description: "with extra client headers", expectSuccess: true },
{
headers: { foo: "bar" },
description: "with non-matching client headers",
expectSuccess: Object.keys(headers).length === 0,
},
];

for (const {
headers: clientHeaders,
description: clientDescription,
expectSuccess,
} of clientHeaderTestCases) {
describe(clientDescription, () => {
let client: Client;
let transport: StreamableHTTPClientTransport;
beforeAll(() => {
client = new Client({
name: "test",
version: "0.0.0",
});
transport = new StreamableHTTPClientTransport(new URL(`${runner.address}/mcp`), {
requestInit: {
headers: clientHeaders,
},
});
});

it("handles requests and sends responses", async () => {
const response = await client.listTools();
expect(response).toBeDefined();
expect(response.tools).toBeDefined();
expect(response.tools.length).toBeGreaterThan(0);
afterAll(async () => {
await client.close();
await transport.close();
});

const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name));
expect(sortedTools[0]?.name).toBe("aggregate");
expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
it(`should ${expectSuccess ? "succeed" : "fail"}`, async () => {
try {
await client.connect(transport);
const response = await client.listTools();
expect(response).toBeDefined();
expect(response.tools).toBeDefined();
expect(response.tools.length).toBeGreaterThan(0);

const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name));
expect(sortedTools[0]?.name).toBe("aggregate");
expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection");
} catch (err) {
if (expectSuccess) {
throw err;
} else {
expect(err).toBeDefined();
expect(err?.toString()).toContain("HTTP 403");
}
}
});
});
}
});
});
}
});
Loading