diff --git a/package-lock.json b/package-lock.json index 83af188c..307ef728 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,6 +17,7 @@ "mongodb-log-writer": "^2.4.1", "mongodb-redact": "^1.1.6", "mongodb-schema": "^12.6.2", + "yargs-parser": "^21.1.1", "zod": "^3.24.2" }, "bin": { @@ -29,6 +30,7 @@ "@redocly/cli": "^1.34.2", "@types/node": "^22.14.0", "@types/simple-oauth2": "^5.0.7", + "@types/yargs-parser": "^21.0.3", "eslint": "^9.24.0", "eslint-config-prettier": "^10.1.1", "globals": "^16.0.0", @@ -4769,6 +4771,13 @@ "@types/webidl-conversions": "*" } }, + "node_modules/@types/yargs-parser": { + "version": "21.0.3", + "resolved": "https://registry.npmjs.org/@types/yargs-parser/-/yargs-parser-21.0.3.tgz", + "integrity": "sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==", + "dev": true, + "license": "MIT" + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.29.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.29.1.tgz", @@ -11218,7 +11227,6 @@ "version": "21.1.1", "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", - "devOptional": true, "license": "ISC", "engines": { "node": ">=12" diff --git a/package.json b/package.json index 94a85898..62f82c4a 100644 --- a/package.json +++ b/package.json @@ -38,6 +38,7 @@ "@redocly/cli": "^1.34.2", "@types/node": "^22.14.0", "@types/simple-oauth2": "^5.0.7", + "@types/yargs-parser": "^21.0.3", "eslint": "^9.24.0", "eslint-config-prettier": "^10.1.1", "globals": "^16.0.0", @@ -58,6 +59,7 @@ "mongodb-log-writer": "^2.4.1", "mongodb-redact": "^1.1.6", "mongodb-schema": "^12.6.2", + "yargs-parser": "^21.1.1", "zod": "^3.24.2" }, "engines": { diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 6d745792..87ce4b05 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -91,7 +91,7 @@ export class ApiClient { throw new Error("Not authenticated. Please run the auth tool first."); } - const url = new URL(`api/atlas/v2${endpoint}`, `${config.apiBaseURL}`); + const url = new URL(`api/atlas/v2${endpoint}`, `${config.apiBaseUrl}`); if (!this.checkTokenExpiry()) { await this.refreshToken(); @@ -119,7 +119,7 @@ export class ApiClient { async authenticate(): Promise { const endpoint = "api/private/unauth/account/device/authorize"; - const authUrl = new URL(endpoint, config.apiBaseURL); + const authUrl = new URL(endpoint, config.apiBaseUrl); const response = await fetch(authUrl, { method: "POST", @@ -128,7 +128,7 @@ export class ApiClient { Accept: "application/json", }, body: new URLSearchParams({ - client_id: config.clientID, + client_id: config.clientId, scope: "openid profile offline_access", grant_type: "urn:ietf:params:oauth:grant-type:device_code", }).toString(), @@ -143,14 +143,14 @@ export class ApiClient { async retrieveToken(device_code: string): Promise { const endpoint = "api/private/unauth/account/device/token"; - const url = new URL(endpoint, config.apiBaseURL); + const url = new URL(endpoint, config.apiBaseUrl); const response = await fetch(url, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded", }, body: new URLSearchParams({ - client_id: config.clientID, + client_id: config.clientId, device_code: device_code, grant_type: "urn:ietf:params:oauth:grant-type:device_code", }).toString(), @@ -179,7 +179,7 @@ export class ApiClient { async refreshToken(token?: OAuthToken): Promise { const endpoint = "api/private/unauth/account/device/token"; - const url = new URL(endpoint, config.apiBaseURL); + const url = new URL(endpoint, config.apiBaseUrl); const response = await fetch(url, { method: "POST", headers: { @@ -187,7 +187,7 @@ export class ApiClient { Accept: "application/json", }, body: new URLSearchParams({ - client_id: config.clientID, + client_id: config.clientId, refresh_token: (token || this.token)?.refresh_token || "", grant_type: "refresh_token", scope: "openid profile offline_access", @@ -213,7 +213,7 @@ export class ApiClient { async revokeToken(token?: OAuthToken): Promise { const endpoint = "api/private/unauth/account/device/token"; - const url = new URL(endpoint, config.apiBaseURL); + const url = new URL(endpoint, config.apiBaseUrl); const response = await fetch(url, { method: "POST", headers: { @@ -222,7 +222,7 @@ export class ApiClient { "User-Agent": config.userAgent, }, body: new URLSearchParams({ - client_id: config.clientID, + client_id: config.clientId, token: (token || this.token)?.access_token || "", token_type_hint: "refresh_token", }).toString(), diff --git a/src/config.ts b/src/config.ts index e9be53f5..0357101c 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,27 +1,97 @@ import path from "path"; import os from "os"; +import argv from "yargs-parser"; + import packageJson from "../package.json" with { type: "json" }; +import fs from "fs"; +const { localDataPath, configPath } = getLocalDataPath(); + +// If we decide to support non-string config options, we'll need to extend the mechanism for parsing +// env variables. +interface UserConfig extends Record { + apiBaseUrl: string; + clientId: string; + stateFile: string; +} -export const config = { +const defaults: UserConfig = { + apiBaseUrl: "https://cloud.mongodb.com/", + clientId: "0oabtxactgS3gHIR0297", + stateFile: path.join(localDataPath, "state.json"), +}; + +const mergedUserConfig = { + ...defaults, + ...getFileConfig(), + ...getEnvConfig(), + ...getCliConfig(), +}; + +const config = { + ...mergedUserConfig, atlasApiVersion: `2025-03-12`, version: packageJson.version, - apiBaseURL: process.env.API_BASE_URL || "https://cloud.mongodb.com/", - clientID: process.env.CLIENT_ID || "0oabtxactgS3gHIR0297", - stateFile: process.env.STATE_FILE || path.resolve("./state.json"), userAgent: `AtlasMCP/${packageJson.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, - localDataPath: getLocalDataPath(), + localDataPath, }; export default config; -function getLocalDataPath() { +function getLocalDataPath(): { localDataPath: string; configPath: string } { + let localDataPath: string | undefined; + let configPath: string | undefined; + if (process.platform === "win32") { const appData = process.env.APPDATA; const localAppData = process.env.LOCALAPPDATA ?? process.env.APPDATA; if (localAppData && appData) { - return path.join(localAppData, "mongodb", "mongodb-mcp"); + localDataPath = path.join(localAppData, "mongodb", "mongodb-mcp"); + configPath = path.join(localDataPath, "mongodb-mcp.conf"); } } - return path.join(os.homedir(), ".mongodb", "mongodb-mcp"); + localDataPath ??= path.join(os.homedir(), ".mongodb", "mongodb-mcp"); + configPath ??= "/etc/mongodb-mcp.conf"; + + fs.mkdirSync(localDataPath, { recursive: true }); + + return { + localDataPath, + configPath, + }; +} + +// Gets the config supplied by the user as environment variables. The variable names +// are prefixed with `MDB_MCP_` and the keys match the UserConfig keys, but are converted +// to SNAKE_UPPER_CASE. +function getEnvConfig(): Partial { + const camelCaseToSNAKE_UPPER_CASE = (str: string): string => { + return str.replace(/([a-z])([A-Z])/g, "$1_$2").toUpperCase(); + }; + + const result: Partial = {}; + for (const key of Object.keys(defaults)) { + const envVarName = `MDB_MCP_${camelCaseToSNAKE_UPPER_CASE(key)}`; + if (process.env[envVarName]) { + result[key] = process.env[envVarName]; + } + } + + return result; +} + +// Gets the config supplied by the user as a JSON file. The file is expected to be located in the local data path +// and named `config.json`. +function getFileConfig(): Partial { + try { + const config = fs.readFileSync(configPath, "utf8"); + return JSON.parse(config); + } catch { + return {}; + } +} + +// Reads the cli args and parses them into a UserConfig object. +function getCliConfig() { + return argv(process.argv.slice(2)) as unknown as Partial; } diff --git a/src/server.ts b/src/server.ts index b72b69ba..6ccb92f5 100644 --- a/src/server.ts +++ b/src/server.ts @@ -4,7 +4,7 @@ import { State, saveState, loadState } from "./state.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { registerAtlasTools } from "./tools/atlas/tools.js"; import { registerMongoDBTools } from "./tools/mongodb/index.js"; -import { config } from "./config.js"; +import config from "./config.js"; import logger, { initializeLogger } from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; @@ -21,14 +21,14 @@ export class Server { this.apiClient = new ApiClient({ token: this.state?.auth.token, - saveToken: (token) => { + saveToken: async (token) => { if (!this.state) { throw new Error("State is not initialized"); } this.state.auth.code = undefined; this.state.auth.token = token; this.state.auth.status = "issued"; - saveState(this.state); + await saveState(this.state); }, }); diff --git a/src/state.ts b/src/state.ts index b020479e..349dd04c 100644 --- a/src/state.ts +++ b/src/state.ts @@ -1,4 +1,4 @@ -import fs from "fs"; +import fs from "fs/promises"; import config from "./config.js"; import { OauthDeviceCode, OAuthToken } from "./common/atlas/apiClient.js"; @@ -8,37 +8,26 @@ export interface State { code?: OauthDeviceCode; token?: OAuthToken; }; + connectionString?: string; } export async function saveState(state: State): Promise { - return new Promise((resolve, reject) => { - fs.writeFile(config.stateFile, JSON.stringify(state), function (err) { - if (err) { - return reject(err); - } - - return resolve(); - }); - }); + await fs.writeFile(config.stateFile, JSON.stringify(state), { encoding: "utf-8" }); } -export async function loadState() { - return new Promise((resolve, reject) => { - fs.readFile(config.stateFile, "utf-8", (err, data) => { - if (err) { - if (err.code === "ENOENT") { - // File does not exist, return default state - const defaultState: State = { - auth: { - status: "not_auth", - }, - }; - return resolve(defaultState); - } else { - return reject(err); - } - } - return resolve(JSON.parse(data) as State); - }); - }); +export async function loadState(): Promise { + try { + const data = await fs.readFile(config.stateFile, "utf-8"); + return JSON.parse(data) as State; + } catch (err: unknown) { + if (err && typeof err === "object" && "code" in err && err.code === "ENOENT") { + return { + auth: { + status: "not_auth", + }, + }; + } + + throw err; + } } diff --git a/src/tools/atlas/auth.ts b/src/tools/atlas/auth.ts index e6964a11..84fe8527 100644 --- a/src/tools/atlas/auth.ts +++ b/src/tools/atlas/auth.ts @@ -11,7 +11,7 @@ export class AuthTool extends AtlasToolBase { protected argsShape = {}; private async isAuthenticated(): Promise { - return isAuthenticated(this.state!, this.apiClient); + return isAuthenticated(this.state, this.apiClient); } async execute(): Promise { @@ -25,11 +25,11 @@ export class AuthTool extends AtlasToolBase { try { const code = await this.apiClient.authenticate(); - this.state!.auth.status = "requested"; - this.state!.auth.code = code; - this.state!.auth.token = undefined; + this.state.auth.status = "requested"; + this.state.auth.code = code; + this.state.auth.token = undefined; - await saveState(this.state!); + await saveState(this.state); return { content: [ diff --git a/src/tools/mongodb/connect.ts b/src/tools/mongodb/connect.ts index 76358454..397d4552 100644 --- a/src/tools/mongodb/connect.ts +++ b/src/tools/mongodb/connect.ts @@ -4,6 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { DbOperationType, MongoDBToolBase } from "./mongodbTool.js"; import { ToolArgs } from "../tool.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; +import { saveState } from "../../state.js"; export class ConnectTool extends MongoDBToolBase { protected name = "connect"; @@ -20,8 +21,8 @@ export class ConnectTool extends MongoDBToolBase { protected async execute({ connectionStringOrClusterName, }: ToolArgs): Promise { + connectionStringOrClusterName ??= this.state.connectionString; if (!connectionStringOrClusterName) { - // TODO: try reconnecting to the default connection return { content: [ { type: "text", text: "No connection details provided." }, @@ -71,5 +72,7 @@ export class ConnectTool extends MongoDBToolBase { }); this.mongodbState.serviceProvider = provider; + this.state.connectionString = connectionString; + await saveState(this.state); } }