diff --git a/src/common/atlas/accessListUtils.ts b/src/common/atlas/accessListUtils.ts index dd759aba..09379217 100644 --- a/src/common/atlas/accessListUtils.ts +++ b/src/common/atlas/accessListUtils.ts @@ -1,5 +1,5 @@ import { ApiClient } from "./apiClient.js"; -import logger, { LogId } from "../logger.js"; +import { LogId } from "../logger.js"; import { ApiClientError } from "./apiClientError.js"; export const DEFAULT_ACCESS_LIST_COMMENT = "Added by MongoDB MCP Server to enable tool access"; @@ -30,7 +30,7 @@ export async function ensureCurrentIpInAccessList(apiClient: ApiClient, projectI params: { path: { groupId: projectId } }, body: [entry], }); - logger.debug({ + apiClient.logger.debug({ id: LogId.atlasIpAccessListAdded, context: "accessListUtils", message: `IP access list created: ${JSON.stringify(entry)}`, @@ -38,14 +38,14 @@ export async function ensureCurrentIpInAccessList(apiClient: ApiClient, projectI } catch (err) { if (err instanceof ApiClientError && err.response?.status === 409) { // 409 Conflict: entry already exists, log info - logger.debug({ + apiClient.logger.debug({ id: LogId.atlasIpAccessListAdded, context: "accessListUtils", message: `IP address ${entry.ipAddress} is already present in the access list for project ${projectId}.`, }); return; } - logger.warning({ + apiClient.logger.warning({ id: LogId.atlasIpAccessListAddFailure, context: "accessListUtils", message: `Error adding IP access list: ${err instanceof Error ? err.message : String(err)}`, diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index c99b43a8..3b2a4dfd 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -4,7 +4,7 @@ import { ApiClientError } from "./apiClientError.js"; import { paths, operations } from "./openapi.js"; import { CommonProperties, TelemetryEvent } from "../../telemetry/types.js"; import { packageInfo } from "../packageInfo.js"; -import logger, { LogId } from "../logger.js"; +import { LoggerBase, LogId } from "../logger.js"; import { createFetch } from "@mongodb-js/devtools-proxy-support"; import * as oauth from "oauth4webapi"; import { Request as NodeFetchRequest } from "node-fetch"; @@ -28,7 +28,7 @@ export interface AccessToken { } export class ApiClient { - private options: { + private readonly options: { baseUrl: string; userAgent: string; credentials?: { @@ -94,7 +94,10 @@ export class ApiClient { }, }; - constructor(options: ApiClientOptions) { + constructor( + options: ApiClientOptions, + public readonly logger: LoggerBase + ) { this.options = { ...options, userAgent: @@ -180,7 +183,7 @@ export class ApiClient { }; } catch (error: unknown) { const err = error instanceof Error ? error : new Error(String(error)); - logger.error({ + this.logger.error({ id: LogId.atlasConnectFailure, context: "apiClient", message: `Failed to request access token: ${err.message}`, @@ -204,7 +207,7 @@ export class ApiClient { } } catch (error: unknown) { const err = error instanceof Error ? error : new Error(String(error)); - logger.error({ + this.logger.error({ id: LogId.atlasApiRevokeFailure, context: "apiClient", message: `Failed to revoke access token: ${err.message}`, diff --git a/src/common/atlas/cluster.ts b/src/common/atlas/cluster.ts index 2e8d7f28..b9a1dc1c 100644 --- a/src/common/atlas/cluster.ts +++ b/src/common/atlas/cluster.ts @@ -1,6 +1,6 @@ import { ClusterDescription20240805, FlexClusterDescription20241113 } from "./openapi.js"; import { ApiClient } from "./apiClient.js"; -import logger, { LogId } from "../logger.js"; +import { LogId } from "../logger.js"; export interface Cluster { name?: string; @@ -87,7 +87,7 @@ export async function inspectCluster(apiClient: ApiClient, projectId: string, cl return formatFlexCluster(cluster); } catch (flexError) { const err = flexError instanceof Error ? flexError : new Error(String(flexError)); - logger.error({ + apiClient.logger.error({ id: LogId.atlasInspectFailure, context: "inspect-cluster", message: `error inspecting cluster: ${err.message}`, diff --git a/src/common/logger.ts b/src/common/logger.ts index 53421d3f..90bf97be 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -3,6 +3,7 @@ import { mongoLogId, MongoLogId, MongoLogManager, MongoLogWriter } from "mongodb import redact from "mongodb-redact"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { LoggingMessageNotification } from "@modelcontextprotocol/sdk/types.js"; +import { EventEmitter } from "events"; export type LogLevel = LoggingMessageNotification["params"]["level"]; @@ -55,12 +56,17 @@ interface LogPayload { context: string; message: string; noRedaction?: boolean | LoggerType | LoggerType[]; + attributes?: Record; } export type LoggerType = "console" | "disk" | "mcp"; -export abstract class LoggerBase { - private defaultUnredactedLogger: LoggerType = "mcp"; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type EventMap = Record | DefaultEventMap; +type DefaultEventMap = [never]; + +export abstract class LoggerBase = DefaultEventMap> extends EventEmitter { + private readonly defaultUnredactedLogger: LoggerType = "mcp"; public log(level: LogLevel, payload: LogPayload): void { // If no explicit value is supplied for unredacted loggers, default to "mcp" @@ -72,7 +78,7 @@ export abstract class LoggerBase { }); } - protected abstract type: LoggerType; + protected abstract readonly type?: LoggerType; protected abstract logCore(level: LogLevel, payload: LogPayload): void; @@ -92,7 +98,7 @@ export abstract class LoggerBase { if ( typeof noRedaction === "object" && Array.isArray(noRedaction) && - this.type !== null && + this.type && noRedaction.indexOf(this.type) !== -1 ) { // If the consumer has supplied noRedaction: array, we skip redacting if our logger @@ -103,78 +109,108 @@ export abstract class LoggerBase { return redact(message); } - info(payload: LogPayload): void { + public info(payload: LogPayload): void { this.log("info", payload); } - error(payload: LogPayload): void { + public error(payload: LogPayload): void { this.log("error", payload); } - debug(payload: LogPayload): void { + public debug(payload: LogPayload): void { this.log("debug", payload); } - notice(payload: LogPayload): void { + public notice(payload: LogPayload): void { this.log("notice", payload); } - warning(payload: LogPayload): void { + public warning(payload: LogPayload): void { this.log("warning", payload); } - critical(payload: LogPayload): void { + public critical(payload: LogPayload): void { this.log("critical", payload); } - alert(payload: LogPayload): void { + public alert(payload: LogPayload): void { this.log("alert", payload); } - emergency(payload: LogPayload): void { + public emergency(payload: LogPayload): void { this.log("emergency", payload); } } export class ConsoleLogger extends LoggerBase { - protected type: LoggerType = "console"; + protected readonly type: LoggerType = "console"; protected logCore(level: LogLevel, payload: LogPayload): void { const { id, context, message } = payload; - console.error(`[${level.toUpperCase()}] ${id.__value} - ${context}: ${message} (${process.pid})`); + console.error( + `[${level.toUpperCase()}] ${id.__value} - ${context}: ${message} (${process.pid}${this.serializeAttributes(payload.attributes)})` + ); } -} -export class DiskLogger extends LoggerBase { - private constructor(private logWriter: MongoLogWriter) { - super(); + private serializeAttributes(attributes?: Record): string { + if (!attributes || Object.keys(attributes).length === 0) { + return ""; + } + return `, ${Object.entries(attributes) + .map(([key, value]) => `${key}=${value}`) + .join(", ")}`; } +} - protected type: LoggerType = "disk"; - - static async fromPath(logPath: string): Promise { - await fs.mkdir(logPath, { recursive: true }); - - const manager = new MongoLogManager({ - directory: logPath, - retentionDays: 30, - onwarn: console.warn, - onerror: console.error, - gzip: false, - retentionGB: 1, - }); +export class DiskLogger extends LoggerBase<{ initialized: [] }> { + private bufferedMessages: { level: LogLevel; payload: LogPayload }[] = []; + private logWriter?: MongoLogWriter; - await manager.cleanupOldLogFiles(); + public constructor(logPath: string, onError: (error: Error) => void) { + super(); - const logWriter = await manager.createLogWriter(); + void this.initialize(logPath, onError); + } - return new DiskLogger(logWriter); + private async initialize(logPath: string, onError: (error: Error) => void): Promise { + try { + await fs.mkdir(logPath, { recursive: true }); + + const manager = new MongoLogManager({ + directory: logPath, + retentionDays: 30, + onwarn: console.warn, + onerror: console.error, + gzip: false, + retentionGB: 1, + }); + + await manager.cleanupOldLogFiles(); + + this.logWriter = await manager.createLogWriter(); + + for (const message of this.bufferedMessages) { + this.logCore(message.level, message.payload); + } + this.bufferedMessages = []; + this.emit("initialized"); + } catch (error: unknown) { + onError(error as Error); + } } + protected type: LoggerType = "disk"; + protected logCore(level: LogLevel, payload: LogPayload): void { + if (!this.logWriter) { + // If the log writer is not initialized, buffer the message + this.bufferedMessages.push({ level, payload }); + return; + } + const { id, context, message } = payload; const mongoDBLevel = this.mapToMongoDBLogLevel(level); - this.logWriter[mongoDBLevel]("MONGODB-MCP", id, context, message); + this.logWriter[mongoDBLevel]("MONGODB-MCP", id, context, message, payload.attributes); } private mapToMongoDBLogLevel(level: LogLevel): "info" | "warn" | "error" | "debug" | "fatal" { @@ -199,11 +235,11 @@ export class DiskLogger extends LoggerBase { } export class McpLogger extends LoggerBase { - constructor(private server: McpServer) { + public constructor(private readonly server: McpServer) { super(); } - type: LoggerType = "mcp"; + protected readonly type: LoggerType = "mcp"; protected logCore(level: LogLevel, payload: LogPayload): void { // Only log if the server is connected @@ -219,35 +255,41 @@ export class McpLogger extends LoggerBase { } export class CompositeLogger extends LoggerBase { - // This is not a real logger type - it should not be used anyway. - protected type: LoggerType = "composite" as unknown as LoggerType; + protected readonly type?: LoggerType; - private loggers: LoggerBase[] = []; + private readonly loggers: LoggerBase[] = []; + private readonly attributes: Record = {}; constructor(...loggers: LoggerBase[]) { super(); - this.setLoggers(...loggers); + this.loggers = loggers; } - setLoggers(...loggers: LoggerBase[]): void { - if (loggers.length === 0) { - throw new Error("At least one logger must be provided"); - } - this.loggers = [...loggers]; + public addLogger(logger: LoggerBase): void { + this.loggers.push(logger); } public log(level: LogLevel, payload: LogPayload): void { // Override the public method to avoid the base logger redacting the message payload for (const logger of this.loggers) { - logger.log(level, payload); + logger.log(level, { ...payload, attributes: { ...this.attributes, ...payload.attributes } }); } } protected logCore(): void { throw new Error("logCore should never be invoked on CompositeLogger"); } + + public setAttribute(key: string, value: string): void { + this.attributes[key] = value; + } } -const logger = new CompositeLogger(new ConsoleLogger()); -export default logger; +export class NullLogger extends LoggerBase { + protected type?: LoggerType; + + protected logCore(): void { + // No-op logger, does not log anything + } +} diff --git a/src/common/session.ts b/src/common/session.ts index d70f7c6e..815f9f06 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -1,6 +1,6 @@ import { ApiClient, ApiClientCredentials } from "./atlas/apiClient.js"; import { Implementation } from "@modelcontextprotocol/sdk/types.js"; -import logger, { LogId } from "./logger.js"; +import { CompositeLogger, LogId } from "./logger.js"; import EventEmitter from "events"; import { AtlasClusterConnectionInfo, @@ -16,6 +16,7 @@ export interface SessionOptions { apiClientId?: string; apiClientSecret?: string; connectionManager?: ConnectionManager; + logger: CompositeLogger; } export type SessionEvents = { @@ -34,9 +35,13 @@ export class Session extends EventEmitter { version: string; }; - constructor({ apiBaseUrl, apiClientId, apiClientSecret, connectionManager }: SessionOptions) { + public logger: CompositeLogger; + + constructor({ apiBaseUrl, apiClientId, apiClientSecret, connectionManager, logger }: SessionOptions) { super(); + this.logger = logger; + const credentials: ApiClientCredentials | undefined = apiClientId && apiClientSecret ? { @@ -45,7 +50,7 @@ export class Session extends EventEmitter { } : undefined; - this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }); + this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger); this.connectionManager = connectionManager ?? new ConnectionManager(); this.connectionManager.on("connection-succeeded", () => this.emit("connect")); @@ -70,7 +75,7 @@ export class Session extends EventEmitter { await this.connectionManager.disconnect(); } catch (err: unknown) { const error = err instanceof Error ? err : new Error(String(err)); - logger.error({ + this.logger.error({ id: LogId.mongodbDisconnectFailure, context: "session", message: `Error closing service provider: ${error.message}`, @@ -90,7 +95,7 @@ export class Session extends EventEmitter { }) .catch((err: unknown) => { const error = err instanceof Error ? err : new Error(String(err)); - logger.error({ + this.logger.error({ id: LogId.atlasDeleteDatabaseUserFailure, context: "session", message: `Error deleting previous database user: ${error.message}`, diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts index 20ef98dd..9194c252 100644 --- a/src/common/sessionStore.ts +++ b/src/common/sessionStore.ts @@ -1,6 +1,5 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import logger, { LogId, LoggerBase, McpLogger } from "./logger.js"; +import { LogId, LoggerBase } from "./logger.js"; import { ManagedTimeout, setManagedTimeout } from "./managedTimeout.js"; export class SessionStore { @@ -15,7 +14,8 @@ export class SessionStore { constructor( private readonly idleTimeoutMS: number, - private readonly notificationTimeoutMS: number + private readonly notificationTimeoutMS: number, + private readonly logger: LoggerBase ) { if (idleTimeoutMS <= 0) { throw new Error("idleTimeoutMS must be greater than 0"); @@ -47,7 +47,7 @@ export class SessionStore { private sendNotification(sessionId: string): void { const session = this.sessions[sessionId]; if (!session) { - logger.warning({ + this.logger.warning({ id: LogId.streamableHttpTransportSessionCloseNotificationFailure, context: "sessionStore", message: `session ${sessionId} not found, no notification delivered`, @@ -61,7 +61,7 @@ export class SessionStore { }); } - setSession(sessionId: string, transport: StreamableHTTPServerTransport, mcpServer: McpServer): void { + setSession(sessionId: string, transport: StreamableHTTPServerTransport, logger: LoggerBase): void { const session = this.sessions[sessionId]; if (session) { throw new Error(`Session ${sessionId} already exists`); @@ -81,7 +81,12 @@ export class SessionStore { () => this.sendNotification(sessionId), this.notificationTimeoutMS ); - this.sessions[sessionId] = { logger: new McpLogger(mcpServer), transport, abortTimeout, notificationTimeout }; + this.sessions[sessionId] = { + transport, + abortTimeout, + notificationTimeout, + logger, + }; } async closeSession(sessionId: string, closeTransport: boolean = true): Promise { @@ -95,7 +100,7 @@ export class SessionStore { try { await session.transport.close(); } catch (error) { - logger.error({ + this.logger.error({ id: LogId.streamableHttpTransportSessionCloseFailure, context: "streamableHttpTransport", message: `Error closing transport ${sessionId}: ${error instanceof Error ? error.message : String(error)}`, diff --git a/src/index.ts b/src/index.ts index 406569b6..e94e866d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,6 @@ #!/usr/bin/env node -import logger, { LogId } from "./common/logger.js"; +import { ConsoleLogger, LogId } from "./common/logger.js"; import { config } from "./common/config.js"; import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; @@ -9,7 +9,7 @@ async function main() { const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = () => { - logger.info({ + transportRunner.logger.info({ id: LogId.serverCloseRequested, context: "server", message: `Server close requested`, @@ -18,7 +18,7 @@ async function main() { transportRunner .close() .then(() => { - logger.info({ + transportRunner.logger.info({ id: LogId.serverClosed, context: "server", message: `Server closed`, @@ -26,7 +26,7 @@ async function main() { process.exit(0); }) .catch((error: unknown) => { - logger.error({ + transportRunner.logger.error({ id: LogId.serverCloseFailure, context: "server", message: `Error closing server: ${error as string}`, @@ -43,20 +43,20 @@ async function main() { try { await transportRunner.start(); } catch (error: unknown) { - logger.info({ + transportRunner.logger.info({ id: LogId.serverCloseRequested, context: "server", message: "Closing server", }); try { await transportRunner.close(); - logger.info({ + transportRunner.logger.info({ id: LogId.serverClosed, context: "server", message: "Server closed", }); } catch (error: unknown) { - logger.error({ + transportRunner.logger.error({ id: LogId.serverCloseFailure, context: "server", message: `Error closing server: ${error as string}`, @@ -67,6 +67,10 @@ async function main() { } main().catch((error: unknown) => { + // At this point, we may be in a very broken state, so we can't rely on the logger + // being functional. Instead, create a brand new ConsoleLogger and log the error + // to the console. + const logger = new ConsoleLogger(); logger.emergency({ id: LogId.serverStartFailure, context: "server", diff --git a/src/resources/resource.ts b/src/resources/resource.ts index f1da56fa..f5902a80 100644 --- a/src/resources/resource.ts +++ b/src/resources/resource.ts @@ -4,7 +4,7 @@ import { UserConfig } from "../common/config.js"; import { Telemetry } from "../telemetry/telemetry.js"; import type { SessionEvents } from "../common/session.js"; import { ReadResourceCallback, ResourceMetadata } from "@modelcontextprotocol/sdk/server/mcp.js"; -import logger, { LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; type PayloadOf = SessionEvents[K][0]; @@ -63,7 +63,7 @@ export function ReactiveResource { this.session.setAgentRunner(this.mcpServer.server.getClientVersion()); this.session.sessionId = new ObjectId().toString(); - logger.info({ + this.session.logger.info({ id: LogId.serverInitialized, context: "server", message: `Server started with transport ${transport.constructor.name} and agent runner ${this.session.agentRunner?.name}`, diff --git a/src/telemetry/telemetry.ts b/src/telemetry/telemetry.ts index 80b430fe..d5c63bae 100644 --- a/src/telemetry/telemetry.ts +++ b/src/telemetry/telemetry.ts @@ -1,7 +1,7 @@ import { Session } from "../common/session.js"; import { BaseEvent, CommonProperties } from "./types.js"; import { UserConfig } from "../common/config.js"; -import logger, { LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; import { ApiClient } from "../common/atlas/apiClient.js"; import { MACHINE_METADATA } from "./constants.js"; import { EventCache } from "./eventCache.js"; @@ -63,14 +63,14 @@ export class Telemetry { onError: (reason, error) => { switch (reason) { case "resolutionError": - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryDeviceIdFailure, context: "telemetry", message: String(error), }); break; case "timeout": - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryDeviceIdTimeout, context: "telemetry", message: "Device ID retrieval timed out", @@ -108,7 +108,7 @@ export class Telemetry { public async emitEvents(events: BaseEvent[]): Promise { try { if (!this.isTelemetryEnabled()) { - logger.info({ + this.session.logger.info({ id: LogId.telemetryEmitFailure, context: "telemetry", message: "Telemetry is disabled.", @@ -119,7 +119,7 @@ export class Telemetry { await this.emit(events); } catch { - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryEmitFailure, context: "telemetry", message: "Error emitting telemetry events.", @@ -174,7 +174,7 @@ export class Telemetry { const cachedEvents = this.eventCache.getEvents(); const allEvents = [...cachedEvents, ...events]; - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryEmitStart, context: "telemetry", message: `Attempting to send ${allEvents.length} events (${cachedEvents.length} cached)`, @@ -183,7 +183,7 @@ export class Telemetry { const result = await this.sendEvents(this.session.apiClient, allEvents); if (result.success) { this.eventCache.clearEvents(); - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryEmitSuccess, context: "telemetry", message: `Sent ${allEvents.length} events successfully: ${JSON.stringify(allEvents, null, 2)}`, @@ -191,7 +191,7 @@ export class Telemetry { return; } - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryEmitFailure, context: "telemetry", message: `Error sending event to client: ${result.error instanceof Error ? result.error.message : String(result.error)}`, diff --git a/src/tools/atlas/atlasTool.ts b/src/tools/atlas/atlasTool.ts index 58cc5849..326c3aec 100644 --- a/src/tools/atlas/atlasTool.ts +++ b/src/tools/atlas/atlasTool.ts @@ -1,7 +1,7 @@ import { ToolBase, ToolCategory, TelemetryToolMetadata, ToolArgs } from "../tool.js"; import { ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; -import logger, { LogId } from "../../common/logger.js"; +import { LogId } from "../../common/logger.js"; import { z } from "zod"; import { ApiClientError } from "../../common/atlas/apiClientError.js"; @@ -79,7 +79,7 @@ For more information on Atlas API access roles, visit: https://www.mongodb.com/d const parsedResult = argsShape.safeParse(args[0]); if (!parsedResult.success) { - logger.debug({ + this.session.logger.debug({ id: LogId.telemetryMetadataError, context: "tool", message: `Error parsing tool arguments: ${parsedResult.error.message}`, diff --git a/src/tools/atlas/connect/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts index 92ffb158..2df76ae9 100644 --- a/src/tools/atlas/connect/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -3,7 +3,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; import { generateSecurePassword } from "../../../helpers/generatePassword.js"; -import logger, { LogId } from "../../../common/logger.js"; +import { LogId } from "../../../common/logger.js"; import { inspectCluster } from "../../../common/atlas/cluster.js"; import { ensureCurrentIpInAccessList } from "../../../common/atlas/accessListUtils.js"; import { AtlasClusterConnectionInfo } from "../../../common/connectionManager.js"; @@ -49,7 +49,7 @@ export class ConnectClusterTool extends AtlasToolBase { case "connected": return "connected"; case "errored": - logger.debug({ + this.session.logger.debug({ id: LogId.atlasConnectFailure, context: "atlas-connect-cluster", message: `error querying cluster: ${currentConectionState.errorReason}`, @@ -127,7 +127,7 @@ export class ConnectClusterTool extends AtlasToolBase { private async connectToCluster(connectionString: string, atlas: AtlasClusterConnectionInfo): Promise { let lastError: Error | undefined = undefined; - logger.debug({ + this.session.logger.debug({ id: LogId.atlasConnectAttempt, context: "atlas-connect-cluster", message: `attempting to connect to cluster: ${this.session.connectedAtlasCluster?.clusterName}`, @@ -146,7 +146,7 @@ export class ConnectClusterTool extends AtlasToolBase { lastError = error; - logger.debug({ + this.session.logger.debug({ id: LogId.atlasConnectFailure, context: "atlas-connect-cluster", message: `error connecting to cluster: ${error.message}`, @@ -182,7 +182,7 @@ export class ConnectClusterTool extends AtlasToolBase { }) .catch((err: unknown) => { const error = err instanceof Error ? err : new Error(String(err)); - logger.debug({ + this.session.logger.debug({ id: LogId.atlasConnectFailure, context: "atlas-connect-cluster", message: `error deleting database user: ${error.message}`, @@ -192,7 +192,7 @@ export class ConnectClusterTool extends AtlasToolBase { throw lastError; } - logger.debug({ + this.session.logger.debug({ id: LogId.atlasConnectSucceeded, context: "atlas-connect-cluster", message: `connected to cluster: ${this.session.connectedAtlasCluster?.clusterName}`, @@ -228,7 +228,7 @@ export class ConnectClusterTool extends AtlasToolBase { // try to connect for about 5 minutes asynchronously void this.connectToCluster(connectionString, atlas).catch((err: unknown) => { const error = err instanceof Error ? err : new Error(String(err)); - logger.error({ + this.session.logger.error({ id: LogId.atlasConnectFailure, context: "atlas-connect-cluster", message: `error connecting to cluster: ${error.message}`, diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 6ff09b2c..708209f8 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -3,7 +3,7 @@ import { ToolArgs, ToolBase, ToolCategory, TelemetryToolMetadata } from "../tool import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../common/errors.js"; -import logger, { LogId } from "../../common/logger.js"; +import { LogId } from "../../common/logger.js"; import { Server } from "../../server.js"; export const DbOperationArgs = { @@ -28,7 +28,7 @@ export abstract class MongoDBToolBase extends ToolBase { try { await this.connectToMongoDB(this.config.connectionString); } catch (error) { - logger.error({ + this.session.logger.error({ id: LogId.mongodbConnectFailure, context: "mongodbTool", message: `Failed to connect to MongoDB instance using the connection string from the config: ${error as string}`, diff --git a/src/tools/tool.ts b/src/tools/tool.ts index 09ab8b69..21f76357 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -2,7 +2,7 @@ import { z, type ZodRawShape, type ZodNever, AnyZodObject } from "zod"; import type { RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { CallToolResult, ToolAnnotations } from "@modelcontextprotocol/sdk/types.js"; import { Session } from "../common/session.js"; -import logger, { LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; import { Telemetry } from "../telemetry/telemetry.js"; import { type ToolEvent } from "../telemetry/types.js"; import { UserConfig } from "../common/config.js"; @@ -73,7 +73,7 @@ export abstract class ToolBase { const callback: ToolCallback = async (...args) => { const startTime = Date.now(); try { - logger.debug({ + this.session.logger.debug({ id: LogId.toolExecute, context: "tool", message: `Executing tool ${this.name}`, @@ -84,7 +84,7 @@ export abstract class ToolBase { await this.emitToolEvent(startTime, result, ...args).catch(() => {}); return result; } catch (error: unknown) { - logger.error({ + this.session.logger.error({ id: LogId.toolExecuteFailure, context: "tool", message: `Error executing ${this.name}: ${error as string}`, @@ -107,7 +107,7 @@ export abstract class ToolBase { const existingTool = tools[this.name]; if (!existingTool) { - logger.warning({ + this.session.logger.warning({ id: LogId.toolUpdateFailure, context: "tool", message: `Tool ${this.name} not found in update`, @@ -159,7 +159,7 @@ export abstract class ToolBase { } if (errorClarification) { - logger.debug({ + this.session.logger.debug({ id: LogId.toolDisabled, context: "tool", message: `Prevented registration of ${this.name} because ${errorClarification} is disabled in the config`, diff --git a/src/transports/base.ts b/src/transports/base.ts index cc58f750..7052f1c4 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -4,22 +4,50 @@ import { Server } from "../server.js"; import { Session } from "../common/session.js"; import { Telemetry } from "../telemetry/telemetry.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { CompositeLogger, ConsoleLogger, DiskLogger, LoggerBase, McpLogger } from "../common/logger.js"; export abstract class TransportRunnerBase { + public logger: LoggerBase; + + protected constructor(protected readonly userConfig: UserConfig) { + const loggers: LoggerBase[] = []; + if (this.userConfig.loggers.includes("stderr")) { + loggers.push(new ConsoleLogger()); + } + + if (this.userConfig.loggers.includes("disk")) { + loggers.push( + new DiskLogger(this.userConfig.logPath, (err) => { + // If the disk logger fails to initialize, we log the error to stderr and exit + console.error("Error initializing disk logger:", err); + process.exit(1); + }) + ); + } + + this.logger = new CompositeLogger(...loggers); + } + protected setupServer(userConfig: UserConfig): Server { + const mcpServer = new McpServer({ + name: packageInfo.mcpServerName, + version: packageInfo.version, + }); + + const loggers = [this.logger]; + if (userConfig.loggers.includes("mcp")) { + loggers.push(new McpLogger(mcpServer)); + } + const session = new Session({ apiBaseUrl: userConfig.apiBaseUrl, apiClientId: userConfig.apiClientId, apiClientSecret: userConfig.apiClientSecret, + logger: new CompositeLogger(...loggers), }); const telemetry = Telemetry.create(session, userConfig); - const mcpServer = new McpServer({ - name: packageInfo.mcpServerName, - version: packageInfo.version, - }); - return new Server({ mcpServer, session, diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 71930341..81141b5f 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,4 +1,4 @@ -import logger, { LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; import { Server } from "../server.js"; import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; @@ -53,11 +53,11 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; - constructor(private userConfig: UserConfig) { - super(); + constructor(userConfig: UserConfig) { + super(userConfig); } - async start() { + async start(): Promise { try { this.server = this.setupServer(this.userConfig); @@ -65,7 +65,7 @@ export class StdioRunner extends TransportRunnerBase { await this.server.connect(transport); } catch (error: unknown) { - logger.emergency({ + this.logger.emergency({ id: LogId.serverStartFailure, context: "server", message: `Fatal error running server: ${error as string}`, diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index f5381756..d0e733db 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -4,7 +4,7 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; import { UserConfig } from "../common/config.js"; -import logger, { LogId } from "../common/logger.js"; +import { LogId } from "../common/logger.js"; import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; @@ -14,39 +14,22 @@ const JSON_RPC_ERROR_CODE_SESSION_ID_INVALID = -32002; const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32003; const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32004; -function withErrorHandling( - fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise -) { - return (req: express.Request, res: express.Response, next: express.NextFunction) => { - fn(req, res, next).catch((error) => { - logger.error({ - id: LogId.streamableHttpTransportRequestFailure, - context: "streamableHttpTransport", - message: `Error handling request: ${error instanceof Error ? error.message : String(error)}`, - }); - res.status(400).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), - }, - }); - }); - }; -} - export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; - private sessionStore: SessionStore; + private sessionStore!: SessionStore; - constructor(private userConfig: UserConfig) { - super(); - this.sessionStore = new SessionStore(this.userConfig.idleTimeoutMs, this.userConfig.notificationTimeoutMs); + constructor(userConfig: UserConfig) { + super(userConfig); } - async start() { + async start(): Promise { const app = express(); + this.sessionStore = new SessionStore( + this.userConfig.idleTimeoutMs, + this.userConfig.notificationTimeoutMs, + this.logger + ); + app.enable("trust proxy"); // needed for reverse proxy support app.use(express.json()); @@ -88,7 +71,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.post( "/mcp", - withErrorHandling(async (req: express.Request, res: express.Response) => { + this.withErrorHandling(async (req: express.Request, res: express.Response) => { const sessionId = req.headers["mcp-session-id"]; if (sessionId) { await handleSessionRequest(req, res); @@ -110,13 +93,15 @@ export class StreamableHttpRunner extends TransportRunnerBase { const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID().toString(), onsessioninitialized: (sessionId) => { - this.sessionStore.setSession(sessionId, transport, server.mcpServer); + server.session.logger.setAttribute("sessionId", sessionId); + + this.sessionStore.setSession(sessionId, transport, server.session.logger); }, onsessionclosed: async (sessionId) => { try { await this.sessionStore.closeSession(sessionId, false); } catch (error) { - logger.error({ + this.logger.error({ id: LogId.streamableHttpTransportSessionCloseFailure, context: "streamableHttpTransport", message: `Error closing session: ${error instanceof Error ? error.message : String(error)}`, @@ -127,7 +112,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { transport.onclose = () => { server.close().catch((error) => { - logger.error({ + this.logger.error({ id: LogId.streamableHttpTransportCloseFailure, context: "streamableHttpTransport", message: `Error closing server: ${error instanceof Error ? error.message : String(error)}`, @@ -141,8 +126,8 @@ export class StreamableHttpRunner extends TransportRunnerBase { }) ); - app.get("/mcp", withErrorHandling(handleSessionRequest)); - app.delete("/mcp", withErrorHandling(handleSessionRequest)); + app.get("/mcp", this.withErrorHandling(handleSessionRequest)); + app.delete("/mcp", this.withErrorHandling(handleSessionRequest)); this.httpServer = await new Promise((resolve, reject) => { const result = app.listen(this.userConfig.httpPort, this.userConfig.httpHost, (err?: Error) => { @@ -154,7 +139,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { }); }); - logger.info({ + this.logger.info({ id: LogId.streamableHttpTransportStarted, context: "streamableHttpTransport", message: `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}`, @@ -176,4 +161,26 @@ export class StreamableHttpRunner extends TransportRunnerBase { }), ]); } + + private withErrorHandling( + fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise + ) { + return (req: express.Request, res: express.Response, next: express.NextFunction) => { + fn(req, res, next).catch((error) => { + this.logger.error({ + id: LogId.streamableHttpTransportRequestFailure, + context: "streamableHttpTransport", + message: `Error handling request: ${error instanceof Error ? error.message : String(error)}`, + }); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + }); + }; + } } diff --git a/tests/integration/common/apiClient.test.ts b/tests/integration/common/apiClient.test.ts index 54bb040d..be627a23 100644 --- a/tests/integration/common/apiClient.test.ts +++ b/tests/integration/common/apiClient.test.ts @@ -2,6 +2,7 @@ import { afterEach, beforeEach, describe, expect, it } from "vitest"; import type { AccessToken } from "../../../src/common/atlas/apiClient.js"; import { ApiClient } from "../../../src/common/atlas/apiClient.js"; import { HTTPServerProxyTestSetup } from "../fixtures/httpsServerProxyTest.js"; +import { NullLogger } from "../../../src/common/logger.js"; describe("ApiClient integration test", () => { describe(`atlas API proxy integration`, () => { @@ -14,14 +15,17 @@ describe("ApiClient integration test", () => { await proxyTestSetup.listen(); process.env.HTTP_PROXY = `https://localhost:${proxyTestSetup.httpsProxyPort}/`; - apiClient = new ApiClient({ - baseUrl: `https://localhost:${proxyTestSetup.httpsServerPort}/`, - credentials: { - clientId: "test-client-id", - clientSecret: "test-client-secret", + apiClient = new ApiClient( + { + baseUrl: `https://localhost:${proxyTestSetup.httpsServerPort}/`, + credentials: { + clientId: "test-client-id", + clientSecret: "test-client-secret", + }, + userAgent: "test-user-agent", }, - userAgent: "test-user-agent", - }); + new NullLogger() + ); }); function withToken(accessToken: string, expired: boolean) { diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index f5a6ab7f..8a8d9dcb 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -9,6 +9,7 @@ import { Telemetry } from "../../src/telemetry/telemetry.js"; import { config } from "../../src/common/config.js"; import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; import { ConnectionManager } from "../../src/common/connectionManager.js"; +import { CompositeLogger } from "../../src/common/logger.js"; interface ParameterInfo { name: string; @@ -61,6 +62,7 @@ export function setupIntegrationTest(getUserConfig: () => UserConfig): Integrati apiClientId: userConfig.apiClientId, apiClientSecret: userConfig.apiClientSecret, connectionManager, + logger: new CompositeLogger(), }); // Mock hasValidAccessToken for tests diff --git a/tests/integration/telemetry.test.ts b/tests/integration/telemetry.test.ts index d3a944f1..9be1b0ad 100644 --- a/tests/integration/telemetry.test.ts +++ b/tests/integration/telemetry.test.ts @@ -4,6 +4,7 @@ import { Session } from "../../src/common/session.js"; import { config } from "../../src/common/config.js"; import nodeMachineId from "node-machine-id"; import { describe, expect, it } from "vitest"; +import { CompositeLogger } from "../../src/common/logger.js"; describe("Telemetry", () => { it("should resolve the actual machine ID", async () => { @@ -14,6 +15,7 @@ describe("Telemetry", () => { const telemetry = Telemetry.create( new Session({ apiBaseUrl: "", + logger: new CompositeLogger(), }), config ); diff --git a/tests/integration/tools/mongodb/metadata/listDatabases.test.ts b/tests/integration/tools/mongodb/metadata/listDatabases.test.ts index 6da75a92..74cbf2e4 100644 --- a/tests/integration/tools/mongodb/metadata/listDatabases.test.ts +++ b/tests/integration/tools/mongodb/metadata/listDatabases.test.ts @@ -21,7 +21,7 @@ describeWithMongoDB("listDatabases tool", (integration) => { const response = await integration.mcpClient().callTool({ name: "list-databases", arguments: {} }); const dbNames = getDbNames(response.content); - expect(defaultDatabases).toStrictEqual(dbNames); + expect(dbNames).toStrictEqual(defaultDatabases); }); }); diff --git a/tests/unit/accessListUtils.test.ts b/tests/unit/accessListUtils.test.ts index 6dc62b65..25a63a9b 100644 --- a/tests/unit/accessListUtils.test.ts +++ b/tests/unit/accessListUtils.test.ts @@ -2,12 +2,14 @@ import { describe, it, expect, vi } from "vitest"; import { ApiClient } from "../../src/common/atlas/apiClient.js"; import { ensureCurrentIpInAccessList, DEFAULT_ACCESS_LIST_COMMENT } from "../../src/common/atlas/accessListUtils.js"; import { ApiClientError } from "../../src/common/atlas/apiClientError.js"; +import { NullLogger } from "../../src/common/logger.js"; describe("accessListUtils", () => { it("should add the current IP to the access list", async () => { const apiClient = { getIpInfo: vi.fn().mockResolvedValue({ currentIpv4Address: "127.0.0.1" } as never), createProjectIpAccessList: vi.fn().mockResolvedValue(undefined as never), + logger: new NullLogger(), } as unknown as ApiClient; await ensureCurrentIpInAccessList(apiClient, "projectId"); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -28,6 +30,7 @@ describe("accessListUtils", () => { { message: "Conflict" } as never ) as never ), + logger: new NullLogger(), } as unknown as ApiClient; await ensureCurrentIpInAccessList(apiClient, "projectId"); // eslint-disable-next-line @typescript-eslint/unbound-method diff --git a/tests/unit/common/apiClient.test.ts b/tests/unit/common/apiClient.test.ts index 0c93f219..a9fb6682 100644 --- a/tests/unit/common/apiClient.test.ts +++ b/tests/unit/common/apiClient.test.ts @@ -1,6 +1,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { ApiClient } from "../../../src/common/atlas/apiClient.js"; import { CommonProperties, TelemetryEvent, TelemetryResult } from "../../../src/telemetry/types.js"; +import { NullLogger } from "../../../src/common/logger.js"; describe("ApiClient", () => { let apiClient: ApiClient; @@ -26,14 +27,17 @@ describe("ApiClient", () => { ]; beforeEach(() => { - apiClient = new ApiClient({ - baseUrl: "https://api.test.com", - credentials: { - clientId: "test-client-id", - clientSecret: "test-client-secret", + apiClient = new ApiClient( + { + baseUrl: "https://api.test.com", + credentials: { + clientId: "test-client-id", + clientSecret: "test-client-secret", + }, + userAgent: "test-user-agent", }, - userAgent: "test-user-agent", - }); + new NullLogger() + ); // @ts-expect-error accessing private property for testing apiClient.getAccessToken = vi.fn().mockResolvedValue("mockToken"); diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index f96952fe..592d60fe 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { Session } from "../../../src/common/session.js"; import { config } from "../../../src/common/config.js"; +import { CompositeLogger } from "../../../src/common/logger.js"; vi.mock("@mongosh/service-provider-node-driver"); const MockNodeDriverServiceProvider = vi.mocked(NodeDriverServiceProvider); @@ -12,6 +13,7 @@ describe("Session", () => { session = new Session({ apiClientId: "test-client-id", apiBaseUrl: "https://api.test.com", + logger: new CompositeLogger(), }); MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider); diff --git a/tests/unit/logger.test.ts b/tests/unit/logger.test.ts index 443494f0..6341d657 100644 --- a/tests/unit/logger.test.ts +++ b/tests/unit/logger.test.ts @@ -1,6 +1,10 @@ import { describe, beforeEach, afterEach, vi, MockInstance, it, expect } from "vitest"; -import { CompositeLogger, ConsoleLogger, LoggerType, LogId, McpLogger } from "../../src/common/logger.js"; +import { CompositeLogger, ConsoleLogger, DiskLogger, LoggerType, LogId, McpLogger } from "../../src/common/logger.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import os from "os"; +import * as path from "path"; +import * as fs from "fs/promises"; +import { once } from "events"; describe("Logger", () => { let consoleErrorSpy: MockInstance; @@ -161,4 +165,143 @@ describe("Logger", () => { }); }); }); + + describe("disk logger", () => { + let logPath: string; + beforeEach(() => { + logPath = path.join(os.tmpdir(), `mcp-logs-test-${Math.random()}-${Date.now()}`); + }); + + const assertNoLogs: () => Promise = async () => { + try { + const files = await fs.readdir(logPath); + expect(files.length).toBe(0); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (err: any) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + if (err?.code !== "ENOENT") { + throw err; + } + } + }; + + it("buffers messages during initialization", async () => { + const diskLogger = new DiskLogger(logPath, (err) => { + expect.fail(`Disk logger should not fail to initialize: ${err}`); + }); + + diskLogger.info({ id: LogId.serverInitialized, context: "test", message: "Test message" }); + await assertNoLogs(); + + await once(diskLogger, "initialized"); + + const files = await fs.readdir(logPath); + expect(files.length).toBe(1); + const logContent = await fs.readFile(path.join(logPath, files[0] as string), "utf-8"); + expect(logContent).toContain("Test message"); + }); + + it("includes attributes in the logs", async () => { + const diskLogger = new DiskLogger(logPath, (err) => { + expect.fail(`Disk logger should not fail to initialize: ${err}`); + }); + + diskLogger.info({ + id: LogId.serverInitialized, + context: "test", + message: "Test message", + attributes: { foo: "bar" }, + }); + await assertNoLogs(); + + await once(diskLogger, "initialized"); + + const files = await fs.readdir(logPath); + expect(files.length).toBe(1); + const logContent = await fs.readFile(path.join(logPath, files[0] as string), "utf-8"); + expect(logContent).toContain("Test message"); + expect(logContent).toContain('"foo":"bar"'); + }); + }); + + describe("CompositeLogger", () => { + describe("with attributes", () => { + it("propagates attributes to child loggers", () => { + const compositeLogger = new CompositeLogger(consoleLogger, mcpLogger); + compositeLogger.setAttribute("foo", "bar"); + + compositeLogger.info({ + id: LogId.serverInitialized, + context: "test", + message: "Test message with attributes", + }); + + expect(consoleErrorSpy).toHaveBeenCalledOnce(); + expect(getLastConsoleMessage()).toContain("foo=bar"); + + expect(mcpLoggerSpy).toHaveBeenCalledOnce(); + // The MCP logger ignores attributes + expect(getLastMcpLogMessage()).not.toContain("foo=bar"); + }); + + it("merges attributes with payload attributes", () => { + const compositeLogger = new CompositeLogger(consoleLogger, mcpLogger); + compositeLogger.setAttribute("foo", "bar"); + + compositeLogger.info({ + id: LogId.serverInitialized, + context: "test", + message: "Test message with attributes", + attributes: { baz: "qux" }, + }); + + expect(consoleErrorSpy).toHaveBeenCalledOnce(); + expect(getLastConsoleMessage()).toContain("foo=bar"); + expect(getLastConsoleMessage()).toContain("baz=qux"); + + expect(mcpLoggerSpy).toHaveBeenCalledOnce(); + // The MCP logger ignores attributes + expect(getLastMcpLogMessage()).not.toContain("foo=bar"); + expect(getLastMcpLogMessage()).not.toContain("baz=qux"); + }); + + it("doesn't impact base logger's attributes", () => { + const childComposite = new CompositeLogger(consoleLogger); + const attributedComposite = new CompositeLogger(consoleLogger, childComposite); + attributedComposite.setAttribute("foo", "bar"); + + attributedComposite.info({ + id: LogId.serverInitialized, + context: "test", + message: "Test message with attributes", + }); + + // We include the console logger twice - once in the attributedComposite + // and another time in the childComposite, so we expect to have 2 console.error + // calls. + expect(consoleErrorSpy).toHaveBeenCalledTimes(2); + expect(getLastConsoleMessage()).toContain("foo=bar"); + + // The base logger should not have the attribute set + consoleLogger.debug({ + id: LogId.serverInitialized, + context: "test", + message: "Another message without attributes", + }); + + expect(consoleErrorSpy).toHaveBeenCalledTimes(3); + expect(getLastConsoleMessage()).not.toContain("foo=bar"); + + // The child composite should not have the attribute set + childComposite.error({ + id: LogId.serverInitialized, + context: "test", + message: "Another message without attributes", + }); + + expect(consoleErrorSpy).toHaveBeenCalledTimes(4); + expect(getLastConsoleMessage()).not.toContain("foo=bar"); + }); + }); + }); }); diff --git a/tests/unit/telemetry.test.ts b/tests/unit/telemetry.test.ts index c5afcdb8..d5b6e553 100644 --- a/tests/unit/telemetry.test.ts +++ b/tests/unit/telemetry.test.ts @@ -5,7 +5,7 @@ import { BaseEvent, TelemetryResult } from "../../src/telemetry/types.js"; import { EventCache } from "../../src/telemetry/eventCache.js"; import { config } from "../../src/common/config.js"; import { afterEach, beforeEach, describe, it, vi, expect } from "vitest"; -import logger, { LogId } from "../../src/common/logger.js"; +import { LogId, NullLogger } from "../../src/common/logger.js"; import { createHmac } from "crypto"; import type { MockedFunction } from "vitest"; @@ -106,7 +106,7 @@ describe("Telemetry", () => { vi.clearAllMocks(); // Setup mocked API client - mockApiClient = vi.mocked(new MockApiClient({ baseUrl: "" })); + mockApiClient = vi.mocked(new MockApiClient({ baseUrl: "" }, new NullLogger())); mockApiClient.sendEvents = vi.fn().mockResolvedValue(undefined); mockApiClient.hasCredentials = vi.fn().mockReturnValue(true); @@ -125,6 +125,7 @@ describe("Telemetry", () => { agentRunner: { name: "test-agent", version: "1.0.0" } as const, close: vi.fn().mockResolvedValue(undefined), setAgentRunner: vi.fn().mockResolvedValue(undefined), + logger: new NullLogger(), } as unknown as Session; telemetry = Telemetry.create(session, config, { @@ -236,7 +237,7 @@ describe("Telemetry", () => { }); it("should handle machine ID resolution failure", async () => { - const loggerSpy = vi.spyOn(logger, "debug"); + const loggerSpy = vi.spyOn(session.logger, "debug"); telemetry = Telemetry.create(session, config, { getRawMachineId: () => Promise.reject(new Error("Failed to get device ID")), @@ -258,7 +259,7 @@ describe("Telemetry", () => { }); it("should timeout if machine ID resolution takes too long", async () => { - const loggerSpy = vi.spyOn(logger, "debug"); + const loggerSpy = vi.spyOn(session.logger, "debug"); telemetry = Telemetry.create(session, config, { getRawMachineId: () => new Promise(() => {}) });