Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
getServerUrlHash,
connectToRemoteServer,
TransportStrategy,
findAvailablePort,
} from './lib/utils'
import { StaticOAuthClientInformationFull, StaticOAuthClientMetadata } from './lib/types'
import { createLazyAuthCoordinator } from './lib/coordination'
Expand All @@ -44,8 +45,8 @@ async function runClient(
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)
// Create a lazy auth coordinator with dynamic port support
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs, findAvailablePort)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down
96 changes: 79 additions & 17 deletions src/lib/coordination.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData } from './mcp-auth-config'
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData, deleteConfigFile } from './mcp-auth-config'
import { EventEmitter } from 'events'
import { Server } from 'http'
import express from 'express'
import { AddressInfo } from 'net'
import { unlinkSync } from 'fs'
import { log, debugLog, setupOAuthCallbackServerWithLongPoll } from './utils'
import { log, debugLog, setupOAuthCallbackServerWithLongPoll, findAvailablePort } from './utils'

export type AuthCoordinator = {
initializeAuth: () => Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }>
Expand Down Expand Up @@ -133,8 +133,9 @@ export function createLazyAuthCoordinator(
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): AuthCoordinator {
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null
findAvailablePortFn?: (preferredPort?: number) => Promise<number>,
): AuthCoordinator & { getActualPort: () => number | undefined } {
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean; actualPort: number } | null = null

return {
initializeAuth: async () => {
Expand All @@ -148,10 +149,17 @@ export function createLazyAuthCoordinator(
debugLog('Initializing auth coordination on-demand', { serverUrlHash, callbackPort })

// Initialize auth using the existing coordinateAuth logic
authState = await coordinateAuth(serverUrlHash, callbackPort, events, authTimeoutMs)
debugLog('Auth coordination completed', { skipBrowserAuth: authState.skipBrowserAuth })
return authState
authState = await coordinateAuth(serverUrlHash, callbackPort, events, authTimeoutMs, findAvailablePortFn)
debugLog('Auth coordination completed', { skipBrowserAuth: authState.skipBrowserAuth, actualPort: authState.actualPort })

// Return without actualPort for compatibility
return {
server: authState.server,
waitForAuthCode: authState.waitForAuthCode,
skipBrowserAuth: authState.skipBrowserAuth,
}
},
getActualPort: () => authState?.actualPort,
}
}

Expand All @@ -160,14 +168,16 @@ export function createLazyAuthCoordinator(
* @param serverUrlHash The hash of the server URL
* @param callbackPort The port to use for the callback server
* @param events The event emitter to use for signaling
* @returns An object with the server, waitForAuthCode function, and a flag indicating if browser auth can be skipped
* @param findAvailablePortFn Optional function to find an available port
* @returns An object with the server, waitForAuthCode function, a flag indicating if browser auth can be skipped, and the actual port used
*/
export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
findAvailablePortFn?: (preferredPort?: number) => Promise<number>,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean; actualPort: number }> {
debugLog('Coordinating authentication', { serverUrlHash, callbackPort })

// Check for a lockfile (disabled on Windows for the time being)
Expand Down Expand Up @@ -207,6 +217,7 @@ export async function coordinateAuth(
server: dummyServer,
waitForAuthCode: dummyWaitForAuthCode,
skipBrowserAuth: true,
actualPort: callbackPort, // Use original port as we're not actually listening
}
} else {
log('Taking over authentication process...')
Expand All @@ -227,16 +238,66 @@ export async function coordinateAuth(

// Create our own lockfile
debugLog('Setting up OAuth callback server', { port: callbackPort })
const { server, waitForAuthCode, authCompletedPromise } = setupOAuthCallbackServerWithLongPoll({
port: callbackPort,
path: '/oauth/callback',
events,
authTimeoutMs,
})

// Try to set up the OAuth callback server
let server: Server
let waitForAuthCode: () => Promise<string>
let authCompletedPromise: Promise<string>
let actualPort = callbackPort

try {
const result = await setupOAuthCallbackServerWithLongPoll({
port: callbackPort,
path: '/oauth/callback',
events,
authTimeoutMs,
})
server = result.server
waitForAuthCode = result.waitForAuthCode
authCompletedPromise = result.authCompletedPromise
} catch (error: any) {
// If we get an EADDRINUSE error, it means another process is using the port
if (error.code === 'EADDRINUSE' && findAvailablePortFn) {
log(`Port ${callbackPort} is already in use. Finding an alternative port...`)
debugLog('Port conflict detected, finding alternative', { originalPort: callbackPort, error: error.message })

// Find a new available port
actualPort = await findAvailablePortFn()
log(`Using dynamically assigned port: ${actualPort}`)

// Delete the old client info since the port has changed
// This will force re-registration with the new port
await deleteConfigFile(serverUrlHash, 'client_info.json')
log('Cleared existing client registration to force re-registration with new port')

// Try again with the new port
const result = await setupOAuthCallbackServerWithLongPoll({
port: actualPort,
path: '/oauth/callback',
events,
authTimeoutMs,
})
server = result.server
waitForAuthCode = result.waitForAuthCode
authCompletedPromise = result.authCompletedPromise
} else if (error.code === 'EADDRINUSE') {
// No dynamic port function provided, fail with clear error
log(`Fatal error: Port ${callbackPort} is already in use by another process.`)
log(`This usually means another instance is already handling OAuth for this server.`)
log(`Please wait for the other instance to complete or terminate it.`)
debugLog('Port conflict detected, no dynamic port function', { port: callbackPort, error: error.message })
throw new Error(`Port ${callbackPort} is already in use. Cannot proceed with OAuth authentication.`)
} else {
// Re-throw other errors
throw error
}
}

// Get the actual port the server is running on
// Get the actual port the server is running on (in case port 0 was used)
const address = server.address() as AddressInfo
const actualPort = address.port
if (actualPort === 0) {
actualPort = address.port
}
debugLog('OAuth callback server running', { port: actualPort })

log(`Creating lockfile for server ${serverUrlHash} with process ${process.pid} on port ${actualPort}`)
Expand Down Expand Up @@ -275,5 +336,6 @@ export async function coordinateAuth(
server,
waitForAuthCode,
skipBrowserAuth: false,
actualPort,
}
}
4 changes: 2 additions & 2 deletions src/lib/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ describe('setupOAuthCallbackServerWithLongPoll', () => {

it('should use custom timeout when authTimeoutMs is provided', async () => {
const customTimeout = 5000
const result = setupOAuthCallbackServerWithLongPoll({
const result = await setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events,
Expand All @@ -902,7 +902,7 @@ describe('setupOAuthCallbackServerWithLongPoll', () => {
})

it('should use default timeout when authTimeoutMs is not provided', async () => {
const result = setupOAuthCallbackServerWithLongPoll({
const result = await setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events,
Expand Down
41 changes: 24 additions & 17 deletions src/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ export async function connectToRemoteServer(
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
export async function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
let authCode: string | null = null
const app = express()

Expand Down Expand Up @@ -510,33 +510,40 @@ export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServe
options.events.emit('auth-code-received', code)
})

const server = app.listen(options.port, () => {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
})
// Wrap server.listen in a Promise to ensure it's listening before we return
return new Promise((resolve, reject) => {
const server = app.listen(options.port, () => {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)

const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => {
if (authCode) {
resolve(authCode)
return
}

const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => {
if (authCode) {
resolve(authCode)
return
options.events.once('auth-code-received', (code) => {
resolve(code)
})
})
}

options.events.once('auth-code-received', (code) => {
resolve(code)
})
// Resolve with the server and related functions once it's listening
resolve({ server, authCode, waitForAuthCode, authCompletedPromise })
})
}

return { server, authCode, waitForAuthCode, authCompletedPromise }

// Handle server errors
server.on('error', reject)
})
}

/**
* Sets up an Express server to handle OAuth callbacks
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
export async function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = await setupOAuthCallbackServerWithLongPoll(options)
return { server, authCode, waitForAuthCode }
}

Expand Down
5 changes: 3 additions & 2 deletions src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
setupSignalHandlers,
getServerUrlHash,
TransportStrategy,
findAvailablePort,
} from './lib/utils'
import { StaticOAuthClientInformationFull, StaticOAuthClientMetadata } from './lib/types'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
Expand All @@ -45,8 +46,8 @@ async function runProxy(
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)
// Create a lazy auth coordinator with dynamic port support
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs, findAvailablePort)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down