diff --git a/.env.example b/.env.example index 0012d25b..a77fb021 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,12 @@ PORT = 3000 GCP_STDIO = false +OAUTH_ENABLED = true +GOOGLE_OAUTH_CLIENT_ID = YOUR_CLIENT_ID +GOOGLE_OAUTH_CLIENT_SECRET = YOUR_CLIENT_SECRET +GOOGLE_OAUTH_AUDIENCE = YOUR_AUDIENCE # generally, same as your client id +GOOGLE_OAUTH_REDIRECT_URI = http://localhost:7777/oauth/callback + OAUTH_PROTECTED_RESOURCE = http://localhost:${PORT}/mcp OAUTH_AUTHORIZATION_SERVER = http://localhost:${PORT}/auth/google OAUTH_AUTHORIZATION_ENDPOINT = https://accounts.google.com/o/oauth2/v2/auth diff --git a/constants.js b/constants.js index 3f743efb..97b95bd7 100644 --- a/constants.js +++ b/constants.js @@ -5,3 +5,4 @@ export const SCOPES = { }; export const BEARER_METHODS_SUPPORTED = ['header']; export const RESPONSE_TYPES_SUPPORTED = ['code']; +export const GCLOUD_AUTH = 'gcloud_auth'; diff --git a/lib/clients.js b/lib/clients.js index 01ba449c..42bd75a6 100644 --- a/lib/clients.js +++ b/lib/clients.js @@ -14,9 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ +import { OAuth2Client } from 'google-auth-library'; +import { GCLOUD_AUTH } from '../constants.js'; + +const CLIENT_ID = process.env.GOOGLE_OAUTH_CLIENT_ID; +const CLIENT_SECRET = process.env.GOOGLE_OAUTH_CLIENT_SECRET; +const REDIRECT_URI = process.env.GOOGLE_OAUTH_REDIRECT_URI; const AUTHORIZATION_HEADER = 'Authorization'; const BEARER_PREFIX = 'Bearer'; -const GCLOUD_AUTH = 'gcloud_auth'; function keyGenerator(projectId, accessToken) { return accessToken !== GCLOUD_AUTH ? projectId + accessToken : projectId; @@ -32,6 +37,7 @@ const clients = { logging: new Map(), billing: new Map(), projects: new Map(), + oauth: new Map(), }; function getAuthClient(accessToken) { @@ -44,6 +50,47 @@ function getAuthClient(accessToken) { }; } +/** + * Wraps an OAuth2Client to be compatible with gRPC-based Google Cloud clients. + * + * Some Google Cloud SDK clients (like Cloud Run, Service Usage, etc.) use gRPC + * and expect the `getRequestHeaders()` method of the auth client to return a + * `Map` of headers. However, the standard `OAuth2Client` from `google-auth-library` + * returns a plain Javascript object (e.g., `{ Authorization: 'Bearer ...' }`). + * + * This wrapper uses a Proxy to intercept calls to `getRequestHeaders`. It calls + * the original method and converts the result from a plain object to a `Map` + * if necessary, ensuring compatibility with gRPC-based clients while maintaining + * the original behavior for other properties. + * + * @param {OAuth2Client} authClient - The original OAuth2Client instance. + * @returns {Proxy} A proxy compatible with gRPC clients. + */ +function wrapForGrpc(authClient) { + return new Proxy(authClient, { + get(target, prop, receiver) { + if (prop === 'getRequestHeaders') { + return async (...args) => { + const headers = await target.getRequestHeaders(...args); + if (headers instanceof Map) { + return headers; + } + // Convert plain object (from OAuth2Client) to Map (expected by grpc-js) + const headerMap = new Map(); + for (const [k, v] of Object.entries(headers)) { + headerMap.set(k, v); + } + return headerMap; + }; + } + return Reflect.get(target, prop, receiver); + }, + }); +} + +// Services that use HTTP/REST (or prefer Object headers) instead of gRPC (Map headers) +const HTTP_SERVICES = ['storage', 'logging']; + export async function getClient( service, key, @@ -55,7 +102,14 @@ export async function getClient( const ClientClass = await loadClient(); const finalOptions = { ...options }; if (accessToken && accessToken !== GCLOUD_AUTH) { - finalOptions.authClient = getAuthClient(accessToken); + const oauthClient = await getOAuthClient(accessToken); + // Storage and Logging use HTTP and expect the native OAuth2Client (headers as Object). + // Other services (Run, ServiceUsage, etc.) use gRPC and expect headers as Map. + if (HTTP_SERVICES.includes(service)) { + finalOptions.authClient = oauthClient; + } else { + finalOptions.authClient = wrapForGrpc(oauthClient); + } } clients[service].set(key, new ClientClass(finalOptions)); } @@ -220,3 +274,19 @@ export async function getProjectsClient(accessToken = GCLOUD_AUTH) { accessToken ); } + +/** + * Gets an OAuth2 Client for the specified access token. + * @param {string} accessToken - The access token. + * @returns {Promise} + */ +export async function getOAuthClient(accessToken) { + // Use the access token itself as the key since it's unique per session/user + if (!clients.oauth.has(accessToken)) { + const client = new OAuth2Client(CLIENT_ID, CLIENT_SECRET, REDIRECT_URI); + client.setCredentials({ access_token: accessToken }); + clients.oauth.set(accessToken, client); + } + + return clients.oauth.get(accessToken); +} diff --git a/lib/cloud-api/auth.js b/lib/cloud-api/auth.js index 0954665a..7fbe8ec5 100644 --- a/lib/cloud-api/auth.js +++ b/lib/cloud-api/auth.js @@ -23,6 +23,7 @@ import { GoogleAuth } from 'google-auth-library'; * @async * @returns {Promise} A promise that resolves to true if GCP auth are found, and false otherwise. */ +//TODO: Rename ensureGCPCredentials to ensureADCCredentials export async function ensureGCPCredentials() { console.error('Checking for Google Cloud Application Default Credentials...'); try { diff --git a/lib/cloud-api/billing.js b/lib/cloud-api/billing.js index eef766ff..48264ec9 100644 --- a/lib/cloud-api/billing.js +++ b/lib/cloud-api/billing.js @@ -30,8 +30,8 @@ import { getBillingClient } from '../clients.js'; * @returns {Promise>} A promise that resolves to an array of billing account objects, * each with 'name', 'displayName', and 'open' status. Returns an empty array on error. */ -export async function listBillingAccounts() { - const client = await getBillingClient(); +export async function listBillingAccounts(accessToken) { + const client = await getBillingClient('global', accessToken); try { const [accounts] = await client.listBillingAccounts(); if (!accounts || accounts.length === 0) { @@ -59,9 +59,10 @@ export async function listBillingAccounts() { */ export async function attachProjectToBillingAccount( projectId, - billingAccountName + billingAccountName, + accessToken ) { - const client = await getBillingClient(); + const client = await getBillingClient(projectId, accessToken); const projectName = `projects/${projectId}`; if (!projectId) { @@ -111,8 +112,8 @@ export async function attachProjectToBillingAccount( * @param {string} projectId - The ID of the project to check. * @returns {Promise} A promise that resolves to true if billing is enabled, false otherwise. */ -export async function isBillingEnabled(projectId) { - const client = await getBillingClient(); +export async function isBillingEnabled(projectId, accessToken) { + const client = await getBillingClient(projectId, accessToken); const projectName = `projects/${projectId}`; try { // getProjectBillingInfo requires cloudbilling.googleapis.com API to check billing status @@ -136,10 +137,14 @@ export async function isBillingEnabled(projectId) { * @param {string} projectId The project ID to check billing for. * @param {function} progressCallback A callback function for progress updates. */ -export async function ensureBillingEnabled(projectId, progressCallback) { - if (!(await isBillingEnabled(projectId))) { +export async function ensureBillingEnabled( + projectId, + accessToken, + progressCallback +) { + if (!(await isBillingEnabled(projectId, accessToken))) { // Billing is disabled, try to fix it. - const accounts = await listBillingAccounts(); + const accounts = await listBillingAccounts(accessToken); if (accounts && accounts.length === 1 && accounts[0].open) { // Exactly one open account found, try to attach it. @@ -151,7 +156,8 @@ export async function ensureBillingEnabled(projectId, progressCallback) { const attachmentResult = await attachProjectToBillingAccount( projectId, - account.name + account.name, + accessToken ); if (!attachmentResult || !attachmentResult.billingEnabled) { diff --git a/lib/cloud-api/build.js b/lib/cloud-api/build.js index bb297459..a230fa32 100644 --- a/lib/cloud-api/build.js +++ b/lib/cloud-api/build.js @@ -53,21 +53,23 @@ export async function triggerCloudBuild( targetRepoName, targetImageUrl, hasDockerfile, + accessToken, progressCallback ) { const serviceName = targetImageUrl.split('/').pop().split(/[:@]/)[0]; const parent = `projects/${projectId}/locations/${location}`; const serviceFullName = `${parent}/services/${serviceName}`; - const buildsClient = await getBuildsClient(projectId); - const runClient = await getRunClient(projectId); + const buildsClient = await getBuildsClient(projectId, accessToken); + const runClient = await getRunClient(projectId, accessToken); const serviceExists = await checkCloudRunServiceExists( projectId, location, serviceName, + accessToken, progressCallback ); - const cloudBuildClient = await getCloudBuildClient(projectId); - const loggingClient = await getLoggingClient(projectId); + const cloudBuildClient = await getCloudBuildClient(projectId, accessToken); + const loggingClient = await getLoggingClient(projectId, accessToken); let buildSteps; const servicePatch = { diff --git a/lib/cloud-api/helpers.js b/lib/cloud-api/helpers.js index 6afe5ffe..de0d1b8c 100644 --- a/lib/cloud-api/helpers.js +++ b/lib/cloud-api/helpers.js @@ -100,8 +100,16 @@ async function checkAndEnableApi( * @returns {Promise} A promise that resolves when the API is enabled. * @throws {Error} If the API fails to enable or if there's an issue checking its status. */ -export async function enableApiWithRetry(projectId, api, progressCallback) { - const serviceUsageClient = await getServiceUsageClient(projectId); +export async function enableApiWithRetry( + projectId, + api, + accessToken, + progressCallback +) { + const serviceUsageClient = await getServiceUsageClient( + projectId, + accessToken + ); const serviceName = `projects/${projectId}/services/${api}`; try { await checkAndEnableApi( @@ -148,20 +156,25 @@ export async function enableApiWithRetry(projectId, api, progressCallback) { * @throws {Error} If an API fails to enable or if there's an issue checking its status. * @returns {Promise} A promise that resolves when all specified APIs are enabled. */ -export async function ensureApisEnabled(projectId, apis, progressCallback) { +export async function ensureApisEnabled( + projectId, + apis, + accessToken, + progressCallback +) { for (const api of PREREQUISITE_APIS) { - await enableApiWithRetry(projectId, api, progressCallback); + await enableApiWithRetry(projectId, api, accessToken, progressCallback); } // Before enabling other APIs, ensure billing is enabled. - await ensureBillingEnabled(projectId, progressCallback); + await ensureBillingEnabled(projectId, accessToken, progressCallback); const message = 'Checking and enabling required APIs...'; console.log(message); if (progressCallback) progressCallback({ level: 'info', data: message }); for (const api of apis) { - await enableApiWithRetry(projectId, api, progressCallback); + await enableApiWithRetry(projectId, api, accessToken, progressCallback); } const successMsg = 'All required APIs are enabled.'; console.log(successMsg); diff --git a/lib/cloud-api/projects.js b/lib/cloud-api/projects.js index c3444899..95ce8f3e 100644 --- a/lib/cloud-api/projects.js +++ b/lib/cloud-api/projects.js @@ -26,8 +26,8 @@ import { getProjectsClient } from '../clients.js'; * @function listProjects * @returns {Promise>} A promise that resolves to an array of project objects, each with an 'id' property. Returns an empty array on error. */ -export async function listProjects() { - const client = await getProjectsClient(); +export async function listProjects(accessToken) { + const client = await getProjectsClient(accessToken); try { const [projects] = await client.searchProjects(); return projects.map((project) => ({ @@ -71,8 +71,8 @@ export function generateProjectId() { * @param {string} [parent] - Optional. The resource name of the parent under which the project is to be created. e.g., "organizations/123" or "folders/456". * @returns {Promise<{projectId: string}|null>} A promise that resolves to an object containing the new project's ID. */ -export async function createProject(projectId, parent) { - const client = await getProjectsClient(); +export async function createProject(projectId, parent, accessToken) { + const client = await getProjectsClient(accessToken); let projectIdToUse = projectId; if (!projectIdToUse) { @@ -112,10 +112,14 @@ export async function createProject(projectId, parent) { * @param {string} [parent] - Optional. The resource name of the parent under which the project is to be created. e.g., "organizations/123" or "folders/456". * @returns {Promise<{projectId: string, billingMessage: string}>} A promise that resolves to an object containing the project ID and a billing status message. */ -export async function createProjectAndAttachBilling(projectIdParam, parent) { +export async function createProjectAndAttachBilling( + projectIdParam, + parent, + accessToken +) { let newProject; try { - newProject = await createProject(projectIdParam, parent); + newProject = await createProject(projectIdParam, parent, accessToken); } catch (error) { throw new Error(`Failed to create project: ${error.message}`); } @@ -128,7 +132,7 @@ export async function createProjectAndAttachBilling(projectIdParam, parent) { let billingMessage = `Project ${projectId} created successfully.`; try { - const billingAccounts = await listBillingAccounts(); + const billingAccounts = await listBillingAccounts(accessToken); if (billingAccounts && billingAccounts.length > 0) { const firstBillingAccount = billingAccounts.find((acc) => acc.open); // Prefer an open account if (firstBillingAccount) { @@ -137,7 +141,8 @@ export async function createProjectAndAttachBilling(projectIdParam, parent) { ); const billingInfo = await attachProjectToBillingAccount( projectId, - firstBillingAccount.name + firstBillingAccount.name, + accessToken ); if (billingInfo && billingInfo.billingEnabled) { billingMessage += ` It has been attached to billing account ${firstBillingAccount.displayName}.`; @@ -171,8 +176,8 @@ export async function createProjectAndAttachBilling(projectIdParam, parent) { * @param {string} projectId - The ID of the project to delete. * @returns {Promise} A promise that resolves when the delete operation is initiated. */ -export async function deleteProject(projectId) { - const client = await getProjectsClient(); +export async function deleteProject(projectId, accessToken) { + const client = await getProjectsClient(accessToken); try { console.log(`Attempting to delete project with ID: ${projectId}`); await client.deleteProject({ name: `projects/${projectId}` }); diff --git a/lib/cloud-api/registry.js b/lib/cloud-api/registry.js index d8f4700a..b7825304 100644 --- a/lib/cloud-api/registry.js +++ b/lib/cloud-api/registry.js @@ -33,12 +33,16 @@ import { logAndProgress } from '../util/helpers.js'; */ export async function ensureArtifactRegistryRepoExists( projectId, + accessToken, location, repositoryId, format = 'DOCKER', progressCallback ) { - const artifactRegistryClient = await getArtifactRegistryClient(projectId); + const artifactRegistryClient = await getArtifactRegistryClient( + projectId, + accessToken + ); const parent = `projects/${projectId}/locations/${location}`; const repoPath = artifactRegistryClient.repositoryPath( projectId, diff --git a/lib/cloud-api/run.js b/lib/cloud-api/run.js index 2045d9a4..e5ed9c99 100644 --- a/lib/cloud-api/run.js +++ b/lib/cloud-api/run.js @@ -23,14 +23,14 @@ import { import { callWithRetry, ensureApisEnabled } from './helpers.js'; import { logAndProgress } from '../util/helpers.js'; -async function listCloudRunLocations(projectId) { +async function listCloudRunLocations(projectId, accessToken) { const listLocationsRequest = { name: `projects/${projectId}`, }; const availableLocations = []; try { - const runClient = await getRunClient(projectId); + const runClient = await getRunClient(projectId, accessToken); console.log('Listing Cloud Run supported locations:'); const iterable = runClient.listLocationsAsync(listLocationsRequest); for await (const location of iterable) { @@ -51,11 +51,11 @@ async function listCloudRunLocations(projectId) { * @param {string} projectId - The Google Cloud project ID. * @returns {Promise} - A promise that resolves to an object mapping region to list of service objects in that region. */ -export async function listServices(projectId) { - const runClient = await getRunClient(projectId); +export async function listServices(projectId, accessToken) { + const runClient = await getRunClient(projectId, accessToken); - await ensureApisEnabled(projectId, ['run.googleapis.com']); - const locations = await listCloudRunLocations(projectId); + await ensureApisEnabled(projectId, ['run.googleapis.com'], accessToken); + const locations = await listCloudRunLocations(projectId, accessToken); const allServices = {}; for (const location of locations) { @@ -85,8 +85,8 @@ export async function listServices(projectId) { * @param {string} serviceId - The ID of the Cloud Run service. * @returns {Promise} - A promise that resolves to the service object. */ -export async function getService(projectId, location, serviceId) { - const runClient = await getRunClient(projectId); +export async function getService(projectId, location, serviceId, accessToken) { + const runClient = await getRunClient(projectId, accessToken); const servicePath = runClient.servicePath(projectId, location, serviceId); @@ -125,10 +125,10 @@ export async function getServiceLogs( projectId, location, serviceId, + accessToken, requestOptions ) { - const loggingClient = await getLoggingClient(projectId); - + const loggingClient = await getLoggingClient(projectId, accessToken); try { const LOG_SEVERITY = 'DEFAULT'; // e.g., 'DEFAULT', 'INFO', 'WARNING', 'ERROR' const PAGE_SIZE = 100; // Number of log entries to retrieve per page @@ -226,9 +226,10 @@ export async function checkCloudRunServiceExists( projectId, location, serviceId, + accessToken, progressCallback ) { - const runClient = await getRunClient(projectId); + const runClient = await getRunClient(projectId, accessToken); const servicePath = runClient.servicePath(projectId, location, serviceId); try { await callWithRetry( diff --git a/lib/cloud-api/storage.js b/lib/cloud-api/storage.js index cc8bec0d..21babdc4 100644 --- a/lib/cloud-api/storage.js +++ b/lib/cloud-api/storage.js @@ -34,9 +34,10 @@ export async function ensureStorageBucketExists( projectId, bucketName, location = 'us', + accessToken, progressCallback ) { - const storage = await getStorageClient(projectId); + const storage = await getStorageClient(projectId, accessToken); const bucket = storage.bucket(bucketName); try { const [exists] = await callWithRetry( diff --git a/lib/deployment/deployer.js b/lib/deployment/deployer.js index 8b415c11..13814030 100644 --- a/lib/deployment/deployer.js +++ b/lib/deployment/deployer.js @@ -66,10 +66,11 @@ async function deployToCloudRun( progressCallback, skipIamCheck, deploymentType, + accessToken, zippedSourceContainer, runtime ) { - const runClient = await getRunClient(projectId); + const runClient = await getRunClient(projectId, accessToken); const parent = runClient.locationPath(projectId, location); const servicePath = runClient.servicePath(projectId, location, serviceId); const revisionName = `${serviceId}-${Date.now()}`; // Generate a unique revision name @@ -101,6 +102,7 @@ async function deployToCloudRun( projectId, location, serviceId, + accessToken, progressCallback ); @@ -245,7 +247,8 @@ async function deployWithZip( bucketName, deploymentAttrs, progressCallback, - skipIamCheck + skipIamCheck, + accessToken ) { await logAndProgress( `Attempting direct source deployment...`, @@ -276,6 +279,7 @@ async function deployWithZip( projectId, bucketName, region, + accessToken, progressCallback ); @@ -301,6 +305,7 @@ async function deployWithZip( progressCallback, skipIamCheck, DEPLOYMENT_TYPES.NO_BUILD, + accessToken, container, deploymentAttrs.runtime ); @@ -337,7 +342,8 @@ async function deployWithBuild( hasDockerfile, bucketName, progressCallback, - skipIamCheck + skipIamCheck, + accessToken ) { const imageUrl = `${region}-docker.pkg.dev/${projectId}/${DEPLOYMENT_CONFIG.REPO_NAME}/${serviceName}:${DEPLOYMENT_CONFIG.IMAGE_TAG}`; @@ -345,6 +351,7 @@ async function deployWithBuild( projectId, bucketName, region, + accessToken, progressCallback ); @@ -361,6 +368,7 @@ async function deployWithBuild( await ensureArtifactRegistryRepoExists( projectId, + accessToken, region, DEPLOYMENT_CONFIG.REPO_NAME, 'DOCKER', @@ -377,6 +385,7 @@ async function deployWithBuild( DEPLOYMENT_CONFIG.REPO_NAME, imageUrl, hasDockerfile, + accessToken, progressCallback ); @@ -389,7 +398,8 @@ async function deployWithBuild( builtImageUrl, progressCallback, skipIamCheck, - DEPLOYMENT_TYPES.WITH_BUILD + DEPLOYMENT_TYPES.WITH_BUILD, + accessToken ); await logAndProgress(`Deployment completed successfully`, progressCallback); @@ -414,6 +424,7 @@ export async function deploy({ files, progressCallback, skipIamCheck, + accessToken, }) { if (!projectId) { const errorMsg = @@ -444,6 +455,7 @@ export async function deploy({ await ensureApisEnabled( projectId, REQUIRED_APIS.SOURCE_DEPLOY, + accessToken, progressCallback ); @@ -470,7 +482,8 @@ export async function deploy({ bucketName, deploymentAttrs, progressCallback, - skipIamCheck + skipIamCheck, + accessToken ); } catch (error) { await logAndProgress( @@ -489,7 +502,8 @@ export async function deploy({ hasDockerfile, bucketName, progressCallback, - skipIamCheck + skipIamCheck, + accessToken ); } catch (error) { const deployFailedMessage = `Deployment Failed: ${error.message}`; @@ -518,6 +532,7 @@ export async function deployImage({ imageUrl, progressCallback, skipIamCheck, + accessToken, }) { if (!projectId) { const errorMsg = @@ -547,6 +562,7 @@ export async function deployImage({ await ensureApisEnabled( projectId, REQUIRED_APIS.IMAGE_DEPLOY, + accessToken, progressCallback ); @@ -562,7 +578,8 @@ export async function deployImage({ imageUrl, progressCallback, skipIamCheck, - DEPLOYMENT_TYPES.IMAGE + DEPLOYMENT_TYPES.IMAGE, + accessToken ); await logAndProgress(`Deployment completed successfully`, progressCallback); diff --git a/lib/middleware/oauth.js b/lib/middleware/oauth.js new file mode 100644 index 00000000..ca914ef7 --- /dev/null +++ b/lib/middleware/oauth.js @@ -0,0 +1,92 @@ +import { getOAuthClient } from '../clients.js'; +import { extractAccessToken } from '../util/helpers.js'; + +const TOOLS_CALL_METHOD = 'tools/call'; + +/** + * Verifies the validity of an access token and checks the audience. + * @param {string} accessToken - The access token to verify. + * @param {string} [audience] - Audience to check against the token's audience. + * @returns {Promise} - The token info if valid. + * @throws {Error} - If the token is invalid or the audience does not match. + */ +async function verifyAccessToken(accessToken, audience) { + try { + const client = await getOAuthClient(accessToken); + const tokenInfo = await client.getTokenInfo(accessToken); + + //An expired token will not have audience + if (audience && tokenInfo.aud !== audience) { + throw new Error(`Invalid audience: expected ${audience}`); + } + + console.log('Access token verified successfully.'); + return tokenInfo; + } catch (error) { + console.error('Error verifying access token:', error); + throw error; + } +} + +/** + * Ensures that a valid OAuth token is present in the request headers + * and verifies it against the configured audience. + * @param {object} req - The request object. + * @param {object} res - The response object. + * @throws {Error} - If the token is missing or invalid. + */ +async function ensureOAuthTokenInEnv(req, res) { + try { + const audience = process.env.GOOGLE_OAUTH_AUDIENCE; + console.log('Verifying token'); + if (req.headers.authorization === undefined) { + console.log('No authorization header found in request'); + throw new Error('No authorization header'); + } + console.log('Authorization header found.. Verifying token'); + await verifyAccessToken( + extractAccessToken(req.headers.authorization), + audience + ); + console.log('Token verified'); + } catch (error) { + console.error('Authentication failed: ', error); + throw error; + } +} + +/** + * Middleware to check for OAuth token if OAuth is enabled. + * If OAUTH_ENABLED is 'true', it verifies the Authorization header. + * + * @param {import('express').Request} req + * @param {import('express').Response} res + * @param {import('express').NextFunction} next + */ +export const oauthMiddleware = async (req, res, next) => { + //If OAUTH_ENABLED is not true or the request is not a tools/call, skip the middleware + if ( + process.env.OAUTH_ENABLED !== 'true' || + req.body.method !== TOOLS_CALL_METHOD + ) { + return next(); + } + + try { + await ensureOAuthTokenInEnv(req, res); + next(); + } catch (error) { + console.error('OAuth Middleware Error:', error); + // ensureOAuthTokenInEnv throws an error if auth fails. + // We catch it and send a 401 response. + res.status(401).json({ + jsonrpc: '2.0', + error: { + code: -32001, // Custom auth error code or standard -32000 + message: 'Authentication failed', + data: error.message, + }, + id: null, + }); + } +}; diff --git a/lib/util/helpers.js b/lib/util/helpers.js index 3da67b6d..9ad31ae9 100644 --- a/lib/util/helpers.js +++ b/lib/util/helpers.js @@ -40,3 +40,15 @@ export async function logAndProgress( progressCallback({ level: severity, data: message }); } } + +/** + * Extracts the access token from the Authorization header. + * @param {string} authorizationHeader - The Authorization header string. + * @returns {string | undefined} - The extracted access token or undefined if not found. + */ +export function extractAccessToken(authorizationHeader) { + if (!authorizationHeader) { + return undefined; + } + return authorizationHeader.split(' ')[1]; +} diff --git a/mcp-server.js b/mcp-server.js index b3b4c516..56658296 100755 --- a/mcp-server.js +++ b/mcp-server.js @@ -28,13 +28,19 @@ import { SetLevelRequestSchema } from '@modelcontextprotocol/sdk/types.js'; import { registerPrompts } from './prompts.js'; import { checkGCP } from './lib/cloud-api/metadata.js'; import { ensureGCPCredentials } from './lib/cloud-api/auth.js'; -import '@dotenvx/dotenvx/config'; +import { extractAccessToken } from './lib/util/helpers.js'; +import { oauthMiddleware } from './lib/middleware/oauth.js'; +import { config } from '@dotenvx/dotenvx'; import { SCOPES, + GCLOUD_AUTH, BEARER_METHODS_SUPPORTED, RESPONSE_TYPES_SUPPORTED, } from './constants.js'; +//Suppress the warning related to missing .env file in case of non-OAuth mode +config({ quiet: true, ignore: ['MISSING_ENV_FILE'] }); + const gcpInfo = await checkGCP(); let gcpCredentialsAvailable = false; @@ -68,7 +74,7 @@ const allowedHosts = process.env.ALLOWED_HOSTS ? process.env.ALLOWED_HOSTS.split(',') : undefined; -async function getServer() { +async function getServer(accessToken = GCLOUD_AUTH) { // Create an MCP server with implementation details const server = new McpServer( { @@ -102,6 +108,7 @@ async function getServer() { defaultServiceName, skipIamCheck, gcpCredentialsAvailable, + accessToken, }); } else { console.log( @@ -114,6 +121,7 @@ async function getServer() { defaultServiceName, skipIamCheck, gcpCredentialsAvailable, + accessToken, }); } @@ -154,7 +162,8 @@ if (shouldStartStdio()) { } else { // non-stdio mode console.log('Stdio transport mode is turned off.'); - gcpCredentialsAvailable = await ensureGCPCredentials(); + gcpCredentialsAvailable = + process.env.OAUTH_ENABLED === 'true' || (await ensureGCPCredentials()); const app = enableHostValidation ? createMcpExpressApp({ allowedHosts }) @@ -174,9 +183,10 @@ if (shouldStartStdio()) { getOAuthAuthorizationServer ); - app.post('/mcp', async (req, res) => { + app.post('/mcp', oauthMiddleware, async (req, res) => { console.log('/mcp Received:', req.body); - const server = await getServer(); + const accessToken = extractAccessToken(req.headers.authorization); + const server = await getServer(accessToken); try { const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: undefined, @@ -237,7 +247,8 @@ if (shouldStartStdio()) { // Legacy SSE endpoint for older clients app.get('/sse', async (req, res) => { console.log('/sse Received:', req.body); - const server = await getServer(); + const accessToken = extractAccessToken(req.headers.authorization); + const server = await getServer(accessToken); // Create SSE transport for legacy clients const transport = new SSEServerTransport('/messages', res); sseTransports[transport.sessionId] = transport; diff --git a/test/local/clients.test.js b/test/local/clients.test.js index e5e1fe0f..d10ac138 100644 --- a/test/local/clients.test.js +++ b/test/local/clients.test.js @@ -27,7 +27,9 @@ describe('getClient Helper', () => { assert.strictEqual(client.options.projectId, projectId); assert.ok(client.options.authClient); - const headers = client.options.authClient.getRequestHeaders(); + const headers = await client.options.authClient.getRequestHeaders(); + // 'run' is a gRPC service, so it should be wrapped to return a Map + assert.ok(headers instanceof Map); assert.strictEqual(headers.get('Authorization'), 'Bearer fake-token-1'); }); @@ -96,10 +98,12 @@ describe('getClient Helper', () => { assert.notStrictEqual(client1, client2); - const h1 = client1.options.authClient.getRequestHeaders(); + const h1 = await client1.options.authClient.getRequestHeaders(); + assert.ok(h1 instanceof Map); assert.strictEqual(h1.get('Authorization'), 'Bearer token-A'); - const h2 = client2.options.authClient.getRequestHeaders(); + const h2 = await client2.options.authClient.getRequestHeaders(); + assert.ok(h2 instanceof Map); assert.strictEqual(h2.get('Authorization'), 'Bearer token-B'); }); @@ -126,6 +130,32 @@ describe('getClient Helper', () => { assert.notStrictEqual(runClient, storageClient); // Different maps assert.ok(runClient.options.authClient); assert.ok(storageClient.options.authClient); + + const runHeaders = await runClient.options.authClient.getRequestHeaders(); + assert.ok(runHeaders instanceof Map, 'Run client headers should be a Map'); + + const storageHeaders = + await storageClient.options.authClient.getRequestHeaders(); + assert.ok( + !(storageHeaders instanceof Map), + 'Storage client headers should NOT be a Map' + ); + assert.strictEqual(storageHeaders.Authorization, `Bearer ${accessToken}`); + + const loggingClient = await getClient( + 'logging', + projectId + accessToken, + async () => MockClient, + { projectId }, + accessToken + ); + const loggingHeaders = + await loggingClient.options.authClient.getRequestHeaders(); + assert.ok( + !(loggingHeaders instanceof Map), + 'Logging client headers should NOT be a Map' + ); + assert.strictEqual(loggingHeaders.Authorization, `Bearer ${accessToken}`); }); test('passes additional options correctly', async () => { diff --git a/test/local/mcp-server-stdio.test.js b/test/local/mcp-server-stdio.test.js index 4256cc06..da791318 100644 --- a/test/local/mcp-server-stdio.test.js +++ b/test/local/mcp-server-stdio.test.js @@ -13,7 +13,6 @@ describe('MCP Server stdio startup', () => { stderr = ''; serverProcess = spawn('node', ['mcp-server.js'], { cwd: process.cwd(), - env: { ...process.env, GCP_STDIO: 'true' }, }); stderr = await waitForString(serverProcess.stderr, stdioMsg); }); diff --git a/test/local/mcp-server.test.js b/test/local/mcp-server.test.js index 7bf8365d..8ccc0231 100644 --- a/test/local/mcp-server.test.js +++ b/test/local/mcp-server.test.js @@ -11,6 +11,7 @@ describe('MCP Server in stdio mode', () => { transport = new StdioClientTransport({ command: 'node', args: ['mcp-server.js'], + env: { ...process.env, GCP_STDIO: 'true' }, }); client = new Client({ name: 'test-client', diff --git a/test/local/middleware/oauth.test.js b/test/local/middleware/oauth.test.js new file mode 100644 index 00000000..e60f5e67 --- /dev/null +++ b/test/local/middleware/oauth.test.js @@ -0,0 +1,156 @@ +import { describe, it, mock, beforeEach, afterEach } from 'node:test'; +import assert from 'node:assert/strict'; +import esmock from 'esmock'; + +describe('oauthMiddleware', () => { + let req; + let res; + let next; + let originalEnv; + + beforeEach(() => { + originalEnv = process.env; + process.env = { ...originalEnv }; + req = { + headers: {}, + body: {}, + }; + res = { + headersSent: false, + status: mock.fn(() => res), + json: mock.fn(), + }; + next = mock.fn(); + }); + + afterEach(() => { + process.env = originalEnv; + mock.restoreAll(); + }); + + it('should call next() if OAUTH_ENABLED is not "true"', async () => { + process.env.OAUTH_ENABLED = 'false'; + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + {} + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 1); + assert.strictEqual(res.status.mock.callCount(), 0); + }); + + it('should call next() if method is not tools/call', async () => { + process.env.OAUTH_ENABLED = 'true'; + req.body.method = 'other/method'; + + // We don't verify token if it's not a tool call (based on current implementation) + // So we don't need to mock successful verification here + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + {} + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 1); + }); + + it('should return 401 if Authorization header is missing for tool call', async () => { + process.env.OAUTH_ENABLED = 'true'; + req.body.method = 'tools/call'; + + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + {} + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 0); + assert.strictEqual(res.status.mock.callCount(), 1); + assert.deepStrictEqual(res.status.mock.calls[0].arguments, [401]); + assert.strictEqual(res.json.mock.callCount(), 1); + }); + + it('should verify token and call next() for valid tool call', async () => { + process.env.OAUTH_ENABLED = 'true'; + req.body.method = 'tools/call'; + req.headers.authorization = 'Bearer valid-token'; + + const mockGetTokenInfo = mock.fn(async () => ({ aud: 'valid-audience' })); + const mockGetOAuthClient = mock.fn(async () => ({ + getTokenInfo: mockGetTokenInfo, + })); + + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + { + '../../../lib/clients.js': { + getOAuthClient: mockGetOAuthClient, + }, + } + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 1); + assert.strictEqual(mockGetOAuthClient.mock.callCount(), 1); + assert.strictEqual(mockGetTokenInfo.mock.callCount(), 1); + }); + + it('should return 401 if token verification fails', async () => { + process.env.OAUTH_ENABLED = 'true'; + req.body.method = 'tools/call'; + req.headers.authorization = 'Bearer invalid-token'; + + const mockGetOAuthClient = mock.fn(async () => ({ + getTokenInfo: async () => { + throw new Error('Invalid token'); + }, + })); + + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + { + '../../../lib/clients.js': { + getOAuthClient: mockGetOAuthClient, + }, + } + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 0); + assert.strictEqual(res.status.mock.callCount(), 1); + assert.deepStrictEqual(res.status.mock.calls[0].arguments, [401]); + }); + + it('should return 401 if audience does not match', async () => { + process.env.OAUTH_ENABLED = 'true'; + process.env.GOOGLE_OAUTH_AUDIENCE = 'expected-audience'; + req.body.method = 'tools/call'; + req.headers.authorization = 'Bearer valid-token-wrong-audience'; + + const mockGetTokenInfo = mock.fn(async () => ({ aud: 'wrong-audience' })); + const mockGetOAuthClient = mock.fn(async () => ({ + getTokenInfo: mockGetTokenInfo, + })); + + const { oauthMiddleware } = await esmock( + '../../../lib/middleware/oauth.js', + { + '../../../lib/clients.js': { + getOAuthClient: mockGetOAuthClient, + }, + } + ); + + await oauthMiddleware(req, res, next); + + assert.strictEqual(next.mock.callCount(), 0); + assert.strictEqual(res.status.mock.callCount(), 1); + assert.deepStrictEqual(res.status.mock.calls[0].arguments, [401]); + }); +}); diff --git a/test/local/tools.test.js b/test/local/tools.test.js index 5afebf23..86d59919 100644 --- a/test/local/tools.test.js +++ b/test/local/tools.test.js @@ -42,7 +42,7 @@ describe('registerTools', () => { {}, { '../../lib/cloud-api/projects.js': { - listProjects: () => + listProjects: (token) => Promise.resolve([{ id: 'project1' }, { id: 'project2' }]), }, } @@ -433,7 +433,7 @@ describe('registerTools', () => { content: [ { type: 'text', - text: 'GCP credentials are not available. Please configure your environment.', + text: 'GCP credentials are not available. Please configure your environment using OAuth or `gcloud auth`.', }, ], }); diff --git a/tools/register-tools.js b/tools/register-tools.js index 87ff87f4..53f4969b 100644 --- a/tools/register-tools.js +++ b/tools/register-tools.js @@ -41,7 +41,7 @@ function gcpTool(gcpCredentialsAvailable, fn) { content: [ { type: 'text', - text: 'GCP credentials are not available. Please configure your environment.', + text: 'GCP credentials are not available. Please configure your environment using OAuth or `gcloud auth`.', }, ], }); @@ -59,7 +59,7 @@ function registerListProjectsTool(server, options) { }, gcpTool(options.gcpCredentialsAvailable, async () => { try { - const projects = await listProjects(); + const projects = await listProjects(options.accessToken); return { content: [ { @@ -113,7 +113,11 @@ function registerCreateProjectTool(server, options) { }; } try { - const result = await createProjectAndAttachBilling(projectId); + const result = await createProjectAndAttachBilling( + projectId, + undefined, + options.accessToken + ); return { content: [ { @@ -161,7 +165,7 @@ function registerListServicesTool(server, options) { }; } try { - const allServices = await listServices(project); + const allServices = await listServices(project, options.accessToken); const content = []; for (const region of Object.keys(allServices)) { const serviceList = allServices[region]; @@ -230,7 +234,12 @@ function registerGetServiceTool(server, options) { }; } try { - const serviceDetails = await getService(project, region, service); + const serviceDetails = await getService( + project, + region, + service, + options.accessToken + ); if (serviceDetails) { return { content: [ @@ -299,6 +308,7 @@ function registerGetServiceLogTool(server, options) { project, region, service, + options.accessToken, requestOptions ); @@ -394,6 +404,7 @@ function registerDeployLocalFolderTool(server, options) { region: region, files: [folderPath], skipIamCheck: options.skipIamCheck, // Pass the new flag + accessToken: options.accessToken, progressCallback, }); return { @@ -495,6 +506,7 @@ function registerDeployFileContentsTool(server, options) { region: region, files: files, skipIamCheck: options.skipIamCheck, // Pass the new flag + accessToken: options.accessToken, progressCallback, }); return { @@ -575,6 +587,7 @@ function registerDeployContainerImageTool(server, options) { region: region, imageUrl: imageUrl, skipIamCheck: options.skipIamCheck, + accessToken: options.accessToken, progressCallback, }); return {