Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ You can specify multiple `--ignore-tool` flags to ignore different patterns. Exa
- `*account` - ignores all tools ending with "account" (e.g., `getAccount`, `updateAccount`)
- `exactTool` - ignores only the tool named exactly "exactTool"

* To change the timeout for the OAuth callback (by default `30` seconds), add the `--auth-timeout` flag with a value in seconds. This is useful if the authentication process on the server side takes a long time.

```json
"args": [
"mcp-remote",
"https://remote.mcp.server/sse",
"--auth-timeout",
"60"
]
```

### Transport Strategies

MCP Remote supports different transport strategies when connecting to an MCP server. This allows you to control whether it uses Server-Sent Events (SSE) or HTTP transport, and in what order it tries them.
Expand Down
20 changes: 16 additions & 4 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async function runClient(
host: string,
staticOAuthClientMetadata: StaticOAuthClientMetadata,
staticOAuthClientInfo: StaticOAuthClientInformationFull,
authTimeoutMs: number,
) {
// Set up event emitter for auth flow
const events = new EventEmitter()
Expand All @@ -44,7 +45,7 @@ async function runClient(
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down Expand Up @@ -159,9 +160,20 @@ async function runClient(

// Parse command-line arguments and run the client
parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx client.ts <https://server-url> [callback-port] [--debug]')
.then(({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo }) => {
return runClient(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo)
})
.then(
({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, authTimeoutMs }) => {
return runClient(
serverUrl,
callbackPort,
headers,
transportStrategy,
host,
staticOAuthClientMetadata,
staticOAuthClientInfo,
authTimeoutMs,
)
},
)
.catch((error) => {
console.error('Fatal error:', error)
process.exit(1)
Expand Down
11 changes: 9 additions & 2 deletions src/lib/coordination.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ export async function waitForAuthentication(port: number): Promise<boolean> {
* @param events The event emitter to use for signaling
* @returns An AuthCoordinator object with an initializeAuth method
*/
export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: number, events: EventEmitter): AuthCoordinator {
export function createLazyAuthCoordinator(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): AuthCoordinator {
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null

return {
Expand All @@ -144,7 +149,7 @@ export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: n
if (DEBUG) debugLog('Initializing auth coordination on-demand', { serverUrlHash, callbackPort })

// Initialize auth using the existing coordinateAuth logic
authState = await coordinateAuth(serverUrlHash, callbackPort, events)
authState = await coordinateAuth(serverUrlHash, callbackPort, events, authTimeoutMs)
if (DEBUG) debugLog('Auth coordination completed', { skipBrowserAuth: authState.skipBrowserAuth })
return authState
},
Expand All @@ -162,6 +167,7 @@ export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
if (DEBUG) debugLog('Coordinating authentication', { serverUrlHash, callbackPort })

Expand Down Expand Up @@ -228,6 +234,7 @@ export async function coordinateAuth(
port: callbackPort,
path: '/oauth/callback',
events,
authTimeoutMs,
})

// Get the actual port the server is running on
Expand Down
2 changes: 2 additions & 0 deletions src/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ export interface OAuthCallbackServerOptions {
path: string
/** Event emitter to signal when auth code is received */
events: EventEmitter
/** Timeout in milliseconds for the auth callback server's long poll */
authTimeoutMs?: number
}

// optional tatic OAuth client information
Expand Down
146 changes: 144 additions & 2 deletions src/lib/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { describe, it, expect, vi } from 'vitest'
import { parseCommandLineArgs, shouldIncludeTool, mcpProxy } from './utils'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { parseCommandLineArgs, shouldIncludeTool, mcpProxy, setupOAuthCallbackServerWithLongPoll } from './utils'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { EventEmitter } from 'events'
import express from 'express'

// All sanitizeUrl tests have been moved to the strict-url-sanitise package

Expand Down Expand Up @@ -322,6 +324,100 @@ describe('Feature: Command Line Arguments Parsing', () => {
expect(result.transportStrategy).toBe('sse-only')
expect(result.ignoredTools).toEqual(['tool1', 'tool2'])
})

it('Scenario: Use default auth timeout when not specified', async () => {
// Given command line arguments without --auth-timeout flag
const args = ['https://example.com/sse']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the default auth timeout should be 30000ms
expect(result.authTimeoutMs).toBe(30000)
})

it('Scenario: Parse valid auth timeout in seconds and convert to milliseconds', async () => {
// Given command line arguments with valid --auth-timeout
const args = ['https://example.com/sse', '--auth-timeout', '60']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the timeout should be converted to milliseconds
expect(result.authTimeoutMs).toBe(60000)
})

it('Scenario: Use default timeout when invalid auth timeout value is provided', async () => {
// Given command line arguments with invalid --auth-timeout value
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const args = ['https://example.com/sse', '--auth-timeout', 'invalid']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the default timeout should be used and warning logged
expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: invalid. Must be a positive number.'),
)

consoleSpy.mockRestore()
})

it('Scenario: Use default timeout when negative auth timeout value is provided', async () => {
// Given command line arguments with negative --auth-timeout value
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const args = ['https://example.com/sse', '--auth-timeout', '-30']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the default timeout should be used and warning logged
expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: -30. Must be a positive number.'),
)

consoleSpy.mockRestore()
})

it('Scenario: Use default timeout when zero auth timeout value is provided', async () => {
// Given command line arguments with zero --auth-timeout value
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const args = ['https://example.com/sse', '--auth-timeout', '0']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the default timeout should be used and warning logged
expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: 0. Must be a positive number.'),
)

consoleSpy.mockRestore()
})

it('Scenario: Log when using custom auth timeout', async () => {
// Given command line arguments with custom --auth-timeout value
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const args = ['https://example.com/sse', '--auth-timeout', '45']
const usage = 'test usage'

// When parsing the command line arguments
const result = await parseCommandLineArgs(args, usage)

// Then the custom timeout should be used and logged
expect(result.authTimeoutMs).toBe(45000)
expect(consoleSpy).toHaveBeenCalledWith(expect.stringContaining('Using auth callback timeout: 45 seconds'))

consoleSpy.mockRestore()
})
})

describe('Feature: Tool Filtering with Ignore Patterns', () => {
Expand Down Expand Up @@ -773,3 +869,49 @@ describe('Feature: MCP Proxy', () => {
)
})
})

describe('setupOAuthCallbackServerWithLongPoll', () => {
let server: any
let events: EventEmitter

beforeEach(() => {
events = new EventEmitter()
})

afterEach(() => {
if (server) {
server.close()
server = null
}
})

it('should use custom timeout when authTimeoutMs is provided', async () => {
const customTimeout = 5000
const result = setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events,
authTimeoutMs: customTimeout,
})

server = result.server

// Test that the server was created
expect(server).toBeDefined()
expect(typeof result.waitForAuthCode).toBe('function')
})

it('should use default timeout when authTimeoutMs is not provided', async () => {
const result = setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events,
})

server = result.server

// Test that the server was created with defaults
expect(server).toBeDefined()
expect(typeof result.waitForAuthCode).toBe('function')
})
})
16 changes: 15 additions & 1 deletion src/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServe
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
}, options.authTimeoutMs || 30000)

// If auth completes while we're waiting, send the response immediately
authCompletedPromise
Expand Down Expand Up @@ -716,6 +716,19 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
j++
}

// Parse auth timeout
let authTimeoutMs = 30000 // Default 30 seconds
const authTimeoutIndex = args.indexOf('--auth-timeout')
if (authTimeoutIndex !== -1 && authTimeoutIndex < args.length - 1) {
const timeoutSeconds = parseInt(args[authTimeoutIndex + 1], 10)
if (!isNaN(timeoutSeconds) && timeoutSeconds > 0) {
authTimeoutMs = timeoutSeconds * 1000
log(`Using auth callback timeout: ${timeoutSeconds} seconds`)
} else {
log(`Warning: Ignoring invalid auth timeout value: ${args[authTimeoutIndex + 1]}. Must be a positive number.`)
}
}

if (!serverUrl) {
log(usage)
process.exit(1)
Expand Down Expand Up @@ -791,6 +804,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
staticOAuthClientInfo,
authorizeResource,
ignoredTools,
authTimeoutMs,
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async function runProxy(
staticOAuthClientInfo: StaticOAuthClientInformationFull,
authorizeResource: string,
ignoredTools: string[],
authTimeoutMs: number,
) {
// Set up event emitter for auth flow
const events = new EventEmitter()
Expand All @@ -45,7 +46,7 @@ async function runProxy(
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down Expand Up @@ -158,6 +159,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://se
staticOAuthClientInfo,
authorizeResource,
ignoredTools,
authTimeoutMs,
}) => {
return runProxy(
serverUrl,
Expand All @@ -169,6 +171,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://se
staticOAuthClientInfo,
authorizeResource,
ignoredTools,
authTimeoutMs,
)
},
)
Expand Down
Loading