Skip to content

Commit d261208

Browse files
authored
fix: minor tweaks to enable vscode integration MCP-134 (#467)
1 parent b5e2204 commit d261208

File tree

10 files changed

+233
-117
lines changed

10 files changed

+233
-117
lines changed

src/common/config.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ export interface UserConfig extends CliOptions {
116116
transport: "stdio" | "http";
117117
httpPort: number;
118118
httpHost: string;
119+
httpHeaders: Record<string, string>;
119120
loggers: Array<"stderr" | "disk" | "mcp">;
120121
idleTimeoutMs: number;
121122
notificationTimeoutMs: number;
@@ -137,6 +138,7 @@ export const defaultUserConfig: UserConfig = {
137138
loggers: ["disk", "mcp"],
138139
idleTimeoutMs: 600000, // 10 minutes
139140
notificationTimeoutMs: 540000, // 9 minutes
141+
httpHeaders: {},
140142
};
141143

142144
export const config = setupUserConfig({

src/common/logger.ts

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ export const LogId = {
6464
oidcFlow: mongoLogId(1_008_001),
6565
} as const;
6666

67-
interface LogPayload {
67+
export interface LogPayload {
6868
id: MongoLogId;
6969
context: string;
7070
message: string;
@@ -152,6 +152,26 @@ export abstract class LoggerBase<T extends EventMap<T> = DefaultEventMap> extend
152152
public emergency(payload: LogPayload): void {
153153
this.log("emergency", payload);
154154
}
155+
156+
protected mapToMongoDBLogLevel(level: LogLevel): "info" | "warn" | "error" | "debug" | "fatal" {
157+
switch (level) {
158+
case "info":
159+
return "info";
160+
case "warning":
161+
return "warn";
162+
case "error":
163+
return "error";
164+
case "notice":
165+
case "debug":
166+
return "debug";
167+
case "critical":
168+
case "alert":
169+
case "emergency":
170+
return "fatal";
171+
default:
172+
return "info";
173+
}
174+
}
155175
}
156176

157177
export class ConsoleLogger extends LoggerBase {
@@ -225,26 +245,6 @@ export class DiskLogger extends LoggerBase<{ initialized: [] }> {
225245

226246
this.logWriter[mongoDBLevel]("MONGODB-MCP", id, context, message, payload.attributes);
227247
}
228-
229-
private mapToMongoDBLogLevel(level: LogLevel): "info" | "warn" | "error" | "debug" | "fatal" {
230-
switch (level) {
231-
case "info":
232-
return "info";
233-
case "warning":
234-
return "warn";
235-
case "error":
236-
return "error";
237-
case "notice":
238-
case "debug":
239-
return "debug";
240-
case "critical":
241-
case "alert":
242-
case "emergency":
243-
return "fatal";
244-
default:
245-
return "info";
246-
}
247-
}
248248
}
249249

250250
export class McpLogger extends LoggerBase {
@@ -286,7 +286,11 @@ export class CompositeLogger extends LoggerBase {
286286
public log(level: LogLevel, payload: LogPayload): void {
287287
// Override the public method to avoid the base logger redacting the message payload
288288
for (const logger of this.loggers) {
289-
logger.log(level, { ...payload, attributes: { ...this.attributes, ...payload.attributes } });
289+
const attributes =
290+
Object.keys(this.attributes).length > 0 || payload.attributes
291+
? { ...this.attributes, ...payload.attributes }
292+
: undefined;
293+
logger.log(level, { ...payload, attributes });
290294
}
291295
}
292296

src/helpers/deviceId.ts

Lines changed: 22 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,63 @@
11
import { getDeviceId } from "@mongodb-js/device-id";
2-
import nodeMachineId from "node-machine-id";
2+
import * as nodeMachineId from "node-machine-id";
33
import type { LoggerBase } from "../common/logger.js";
44
import { LogId } from "../common/logger.js";
55

66
export const DEVICE_ID_TIMEOUT = 3000;
77

88
export class DeviceId {
9-
private deviceId: string | undefined = undefined;
10-
private deviceIdPromise: Promise<string> | undefined = undefined;
11-
private abortController: AbortController | undefined = undefined;
9+
private static readonly UnknownDeviceId = Promise.resolve("unknown");
10+
11+
private deviceIdPromise: Promise<string>;
12+
private abortController: AbortController;
1213
private logger: LoggerBase;
1314
private readonly getMachineId: () => Promise<string>;
1415
private timeout: number;
15-
private static instance: DeviceId | undefined = undefined;
1616

1717
private constructor(logger: LoggerBase, timeout: number = DEVICE_ID_TIMEOUT) {
1818
this.logger = logger;
1919
this.timeout = timeout;
2020
this.getMachineId = (): Promise<string> => nodeMachineId.machineId(true);
21+
this.abortController = new AbortController();
22+
23+
this.deviceIdPromise = DeviceId.UnknownDeviceId;
2124
}
2225

23-
public static create(logger: LoggerBase, timeout?: number): DeviceId {
24-
if (this.instance) {
25-
throw new Error("DeviceId instance already exists, use get() to retrieve the device ID");
26-
}
26+
private initialize(): void {
27+
this.deviceIdPromise = getDeviceId({
28+
getMachineId: this.getMachineId,
29+
onError: (reason, error) => {
30+
this.handleDeviceIdError(reason, String(error));
31+
},
32+
timeout: this.timeout,
33+
abortSignal: this.abortController.signal,
34+
});
35+
}
2736

37+
public static create(logger: LoggerBase, timeout?: number): DeviceId {
2838
const instance = new DeviceId(logger, timeout ?? DEVICE_ID_TIMEOUT);
29-
instance.setup();
30-
31-
this.instance = instance;
39+
instance.initialize();
3240

3341
return instance;
3442
}
3543

36-
private setup(): void {
37-
this.deviceIdPromise = this.calculateDeviceId();
38-
}
39-
4044
/**
4145
* Closes the device ID calculation promise and abort controller.
4246
*/
4347
public close(): void {
44-
if (this.abortController) {
45-
this.abortController.abort();
46-
this.abortController = undefined;
47-
}
48-
49-
this.deviceId = undefined;
50-
this.deviceIdPromise = undefined;
51-
DeviceId.instance = undefined;
48+
this.abortController.abort();
5249
}
5350

5451
/**
5552
* Gets the device ID, waiting for the calculation to complete if necessary.
5653
* @returns Promise that resolves to the device ID string
5754
*/
5855
public get(): Promise<string> {
59-
if (this.deviceId) {
60-
return Promise.resolve(this.deviceId);
61-
}
62-
63-
if (this.deviceIdPromise) {
64-
return this.deviceIdPromise;
65-
}
66-
67-
return this.calculateDeviceId();
68-
}
69-
70-
/**
71-
* Internal method that performs the actual device ID calculation.
72-
*/
73-
private async calculateDeviceId(): Promise<string> {
74-
if (!this.abortController) {
75-
this.abortController = new AbortController();
76-
}
77-
78-
this.deviceIdPromise = getDeviceId({
79-
getMachineId: this.getMachineId,
80-
onError: (reason, error) => {
81-
this.handleDeviceIdError(reason, String(error));
82-
},
83-
timeout: this.timeout,
84-
abortSignal: this.abortController.signal,
85-
});
86-
8756
return this.deviceIdPromise;
8857
}
8958

9059
private handleDeviceIdError(reason: string, error: string): void {
91-
this.deviceIdPromise = Promise.resolve("unknown");
60+
this.deviceIdPromise = DeviceId.UnknownDeviceId;
9261

9362
switch (reason) {
9463
case "resolutionError":

src/lib.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
export { Server, type ServerOptions } from "./server.js";
22
export { Telemetry } from "./telemetry/telemetry.js";
33
export { Session, type SessionOptions } from "./common/session.js";
4-
export type { UserConfig } from "./common/config.js";
4+
export { type UserConfig, defaultUserConfig } from "./common/config.js";
5+
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
6+
export { LoggerBase } from "./common/logger.js";
7+
export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js";

src/transports/base.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ export abstract class TransportRunnerBase {
1616

1717
protected constructor(
1818
protected readonly userConfig: UserConfig,
19-
private readonly driverOptions: DriverOptions
19+
private readonly driverOptions: DriverOptions,
20+
additionalLoggers: LoggerBase[]
2021
) {
21-
const loggers: LoggerBase[] = [];
22+
const loggers: LoggerBase[] = [...additionalLoggers];
2223
if (this.userConfig.loggers.includes("stderr")) {
2324
loggers.push(new ConsoleLogger());
2425
}

src/transports/stdio.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { LoggerBase } from "../common/logger.js";
12
import { LogId } from "../common/logger.js";
23
import type { Server } from "../server.js";
34
import { TransportRunnerBase } from "./base.js";
@@ -54,8 +55,8 @@ export function createStdioTransport(): StdioServerTransport {
5455
export class StdioRunner extends TransportRunnerBase {
5556
private server: Server | undefined;
5657

57-
constructor(userConfig: UserConfig, driverOptions: DriverOptions) {
58-
super(userConfig, driverOptions);
58+
constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
59+
super(userConfig, driverOptions, additionalLoggers);
5960
}
6061

6162
async start(): Promise<void> {

src/transports/streamableHttp.ts

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/
44
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
55
import { TransportRunnerBase } from "./base.js";
66
import type { DriverOptions, UserConfig } from "../common/config.js";
7+
import type { LoggerBase } from "../common/logger.js";
78
import { LogId } from "../common/logger.js";
89
import { randomUUID } from "crypto";
910
import { SessionStore } from "../common/sessionStore.js";
@@ -18,8 +19,20 @@ export class StreamableHttpRunner extends TransportRunnerBase {
1819
private httpServer: http.Server | undefined;
1920
private sessionStore!: SessionStore;
2021

21-
constructor(userConfig: UserConfig, driverOptions: DriverOptions) {
22-
super(userConfig, driverOptions);
22+
public get serverAddress(): string {
23+
const result = this.httpServer?.address();
24+
if (typeof result === "string") {
25+
return result;
26+
}
27+
if (typeof result === "object" && result) {
28+
return `http://${result.address}:${result.port}`;
29+
}
30+
31+
throw new Error("Server is not started yet");
32+
}
33+
34+
constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
35+
super(userConfig, driverOptions, additionalLoggers);
2336
}
2437

2538
async start(): Promise<void> {
@@ -32,6 +45,17 @@ export class StreamableHttpRunner extends TransportRunnerBase {
3245

3346
app.enable("trust proxy"); // needed for reverse proxy support
3447
app.use(express.json());
48+
app.use((req, res, next) => {
49+
for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) {
50+
const header = req.headers[key.toLowerCase()];
51+
if (!header || header !== value) {
52+
res.status(403).send({ error: `Invalid value for header "${key}"` });
53+
return;
54+
}
55+
}
56+
57+
next();
58+
});
3559

3660
const handleSessionRequest = async (req: express.Request, res: express.Response): Promise<void> => {
3761
const sessionId = req.headers["mcp-session-id"];
@@ -142,7 +166,7 @@ export class StreamableHttpRunner extends TransportRunnerBase {
142166
this.logger.info({
143167
id: LogId.streamableHttpTransportStarted,
144168
context: "streamableHttpTransport",
145-
message: `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}`,
169+
message: `Server started on ${this.serverAddress}`,
146170
noRedaction: true,
147171
});
148172
}

tests/integration/build.test.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ describe("Build Test", () => {
4141
const esmKeys = Object.keys(esmModule).sort();
4242

4343
expect(cjsKeys).toEqual(esmKeys);
44-
expect(cjsKeys).toEqual(["Server", "Session", "Telemetry"]);
44+
expect(cjsKeys).toIncludeSameMembers([
45+
"Server",
46+
"Session",
47+
"Telemetry",
48+
"StreamableHttpRunner",
49+
"defaultUserConfig",
50+
"LoggerBase",
51+
]);
4552
});
4653
});

0 commit comments

Comments
 (0)