diff --git a/README.md b/README.md index 8cce8d6..bd3cc09 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,24 @@ To bypass authentication, or to emit custom headers on all requests to your remo ] ``` +* To ignore specific tools from the remote server, add the `--ignore-tool` flag. This will filter out tools matching the specified patterns from both `tools/list` responses and block `tools/call` requests. Supports wildcard patterns with `*`. + +```json + "args": [ + "mcp-remote", + "https://remote.mcp.server/sse", + "--ignore-tool", + "delete*", + "--ignore-tool", + "remove*" + ] +``` + +You can specify multiple `--ignore-tool` flags to ignore different patterns. Examples: +- `delete*` - ignores all tools starting with "delete" (e.g., `deleteTask`, `deleteUser`) +- `*account` - ignores all tools ending with "account" (e.g., `getAccount`, `updateAccount`) +- `exactTool` - ignores only the tool named exactly "exactTool" + ### Transport Strategies MCP Remote supports different transport strategies when connecting to an MCP server. This allows you to control whether it uses Server-Sent Events (SSE) or HTTP transport, and in what order it tries them. diff --git a/src/lib/utils.test.ts b/src/lib/utils.test.ts new file mode 100644 index 0000000..1a6d9a8 --- /dev/null +++ b/src/lib/utils.test.ts @@ -0,0 +1,775 @@ +import { describe, it, expect, vi } from 'vitest' +import { parseCommandLineArgs, shouldIncludeTool, mcpProxy } from './utils' +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' + +// All sanitizeUrl tests have been moved to the strict-url-sanitise package + +describe('Feature: Command Line Arguments Parsing', () => { + it('Scenario: Parse basic server URL', async () => { + // Given command line arguments with only a server URL + const args = ['https://example.com/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the server URL should be correctly extracted + expect(result.serverUrl).toBe('https://example.com/sse') + expect(typeof result.serverUrl).toBe('string') + }) + + it('Scenario: Parse server URL with callback port', async () => { + // Given command line arguments with server URL and port + const args = ['https://example.com/sse', '3000'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then both server URL and callback port should be correctly extracted + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.callbackPort).toBe(3000) + }) + + it('Scenario: Parse localhost URL with HTTP protocol', async () => { + // Given command line arguments with localhost HTTP URL + const args = ['http://localhost:8080/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the localhost HTTP URL should be accepted + expect(result.serverUrl).toBe('http://localhost:8080/sse') + }) + + it('Scenario: Parse 127.0.0.1 URL with HTTP protocol', async () => { + // Given command line arguments with 127.0.0.1 HTTP URL + const args = ['http://127.0.0.1:8080/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the 127.0.0.1 HTTP URL should be accepted + expect(result.serverUrl).toBe('http://127.0.0.1:8080/sse') + }) + + it('Scenario: Parse single custom header', async () => { + // Given command line arguments with a custom header + const args = ['https://example.com/sse', '--header', 'foo: taz'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the custom header should be correctly parsed + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.headers).toEqual({ foo: 'taz' }) + }) + + it('Scenario: Parse multiple custom headers', async () => { + // Given command line arguments with multiple custom headers + const args = ['https://example.com/sse', '--header', 'Authorization: Bearer token123', '--header', 'Content-Type: application/json'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then all custom headers should be correctly parsed + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.headers).toEqual({ + Authorization: 'Bearer token123', + 'Content-Type': 'application/json', + }) + }) + + it('Scenario: Ignore invalid header format', async () => { + // Given command line arguments with an invalid header format + const args = ['https://example.com/sse', '--header', 'invalid-header-format'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the invalid header should be ignored and headers should be empty + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.headers).toEqual({}) + }) + + it('Scenario: Handle --allow-http flag for non-localhost URLs', async () => { + // Given command line arguments with HTTP URL and --allow-http flag + const args = ['http://example.com/sse', '--allow-http'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the HTTP URL should be accepted due to --allow-http flag + expect(result.serverUrl).toBe('http://example.com/sse') + }) + + it('Scenario: Accept HTTPS URLs without --allow-http flag', async () => { + // Given command line arguments with HTTPS URL only + const args = ['https://example.com/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the HTTPS URL should be accepted without any additional flags + expect(result.serverUrl).toBe('https://example.com/sse') + }) + + it('Scenario: Handle --allow-http with other arguments', async () => { + // Given command line arguments with HTTP URL, port, --allow-http flag, and custom header + const args = ['http://example.com/sse', '4000', '--allow-http', '--header', 'Authorization: Bearer abc123'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then all arguments should be correctly parsed including HTTP URL acceptance + expect(result.serverUrl).toBe('http://example.com/sse') + expect(result.callbackPort).toBe(4000) + expect(result.headers).toEqual({ Authorization: 'Bearer abc123' }) + }) + + it('Scenario: Use default transport strategy when not specified', async () => { + // Given command line arguments with only server URL + const args = ['https://example.com/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the default transport strategy should be http-first + expect(result.transportStrategy).toBe('http-first') + }) + + it('Scenario: Parse transport strategy sse-only', async () => { + // Given command line arguments with --transport sse-only + const args = ['https://example.com/sse', '--transport', 'sse-only'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the transport strategy should be set to sse-only + expect(result.transportStrategy).toBe('sse-only') + }) + + it('Scenario: Parse transport strategy http-only', async () => { + // Given command line arguments with --transport http-only + const args = ['https://example.com/sse', '--transport', 'http-only'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the transport strategy should be set to http-only + expect(result.transportStrategy).toBe('http-only') + }) + + it('Scenario: Parse transport strategy sse-first', async () => { + // Given command line arguments with --transport sse-first + const args = ['https://example.com/sse', '--transport', 'sse-first'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the transport strategy should be set to sse-first + expect(result.transportStrategy).toBe('sse-first') + }) + + it('Scenario: Parse transport strategy http-first', async () => { + // Given command line arguments with --transport http-first + const args = ['https://example.com/sse', '--transport', 'http-first'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the transport strategy should be set to http-first + expect(result.transportStrategy).toBe('http-first') + }) + + it('Scenario: Ignore invalid transport strategy and use default', async () => { + // Given command line arguments with invalid transport strategy + const args = ['https://example.com/sse', '--transport', 'invalid-strategy'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the invalid strategy should be ignored and default should be used + expect(result.transportStrategy).toBe('http-first') // Should fallback to default + }) + + it('Scenario: Use default host when not specified', async () => { + // Given command line arguments with only server URL + const args = ['https://example.com/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the default host should be localhost + expect(result.host).toBe('localhost') + }) + + it('Scenario: Parse custom IP host', async () => { + // Given command line arguments with custom IP host + const args = ['https://example.com/sse', '--host', '127.0.0.1'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the custom IP host should be correctly set + expect(result.host).toBe('127.0.0.1') + }) + + it('Scenario: Parse custom domain host', async () => { + // Given command line arguments with custom domain host + const args = ['https://example.com/sse', '--host', 'myserver.local'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the custom domain host should be correctly set + expect(result.host).toBe('myserver.local') + }) + + it('Scenario: Handle host with multiple other arguments', async () => { + // Given command line arguments with host, port, and transport strategy + const args = ['https://example.com/sse', '3000', '--host', 'custom.host.com', '--transport', 'sse-only'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then all arguments should be correctly parsed including the host + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.callbackPort).toBe(3000) + expect(result.host).toBe('custom.host.com') + expect(result.transportStrategy).toBe('sse-only') + }) + + it('Scenario: Return empty ignored tools array when none specified', async () => { + // Given command line arguments without --ignore-tool flags + const args = ['https://example.com/sse'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the ignored tools array should be empty + expect(result.ignoredTools).toEqual([]) + }) + + it('Scenario: Parse single ignored tool', async () => { + // Given command line arguments with one --ignore-tool flag + const args = ['https://example.com/sse', '--ignore-tool', 'foo'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the ignored tools array should contain the specified tool + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.ignoredTools).toEqual(['foo']) + }) + + it('Scenario: Parse multiple ignored tools', async () => { + // Given command line arguments with multiple --ignore-tool flags + const args = ['https://example.com/sse', '--ignore-tool', 'foo', '--ignore-tool', 'bar', '--ignore-tool', 'baz'] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then the ignored tools array should contain all specified tools + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.ignoredTools).toEqual(['foo', 'bar', 'baz']) + }) + + it('Scenario: Handle ignored tools with other arguments', async () => { + // Given command line arguments with ignored tools mixed with other arguments + const args = [ + 'https://example.com/sse', + '4000', + '--ignore-tool', + 'tool1', + '--host', + 'localhost', + '--ignore-tool', + 'tool2', + '--transport', + 'sse-only', + ] + const usage = 'test usage' + + // When parsing the command line arguments + const result = await parseCommandLineArgs(args, usage) + + // Then all arguments should be correctly parsed including ignored tools + expect(result.serverUrl).toBe('https://example.com/sse') + expect(result.callbackPort).toBe(4000) + expect(result.host).toBe('localhost') + expect(result.transportStrategy).toBe('sse-only') + expect(result.ignoredTools).toEqual(['tool1', 'tool2']) + }) +}) + +describe('Feature: Tool Filtering with Ignore Patterns', () => { + it('Scenario: Single wildcard pattern ignores matching tools', () => { + // Given ignore patterns with create* wildcard + const ignorePatterns = ['create*'] + + // When checking if createTask should be included + const result1 = shouldIncludeTool(ignorePatterns, 'createTask') + // Then it should be excluded (return false) + expect(result1).toBe(false) + + // When checking if getTask should be included + const result2 = shouldIncludeTool(ignorePatterns, 'getTask') + // Then it should be included (return true) + expect(result2).toBe(true) + }) + + it('Scenario: Multiple wildcard patterns ignore matching tools', () => { + // Given ignore patterns with create* and put* wildcards + const ignorePatterns = ['create*', 'put*'] + + // When checking if createTask should be included + const result1 = shouldIncludeTool(ignorePatterns, 'createTask') + // Then it should be excluded (return false) + expect(result1).toBe(false) + + // When checking if infoTask should be included + const result2 = shouldIncludeTool(ignorePatterns, 'infoTask') + // Then it should be included (return true) + expect(result2).toBe(true) + }) + + it('Scenario: Suffix wildcard pattern ignores matching tools', () => { + // Given ignore patterns with *account suffix wildcard + const ignorePatterns = ['*account'] + + // When checking various account-related tools + const result1 = shouldIncludeTool(ignorePatterns, 'getAccount') + const result2 = shouldIncludeTool(ignorePatterns, 'putAccount') + const result3 = shouldIncludeTool(ignorePatterns, 'account') + + // Then all should be excluded (return false) + expect(result1).toBe(false) + expect(result2).toBe(false) + expect(result3).toBe(false) + }) + + it('Scenario: Empty ignore patterns include all tools', () => { + // Given empty ignore patterns + const ignorePatterns: string[] = [] + + // When checking any tool + const result = shouldIncludeTool(ignorePatterns, 'anyTool') + + // Then it should be included (return true) + expect(result).toBe(true) + }) + + it('Scenario: Non-matching patterns include tools', () => { + // Given ignore patterns that don't match the tool + const ignorePatterns = ['delete*', 'remove*'] + + // When checking a tool that doesn't match any pattern + const result = shouldIncludeTool(ignorePatterns, 'createTask') + + // Then it should be included (return true) + expect(result).toBe(true) + }) + + it('Scenario: Exact match without wildcards', () => { + // Given ignore patterns with exact tool names + const ignorePatterns = ['exactTool', 'anotherTool'] + + // When checking the exact tool name + const result1 = shouldIncludeTool(ignorePatterns, 'exactTool') + // Then it should be excluded (return false) + expect(result1).toBe(false) + + // When checking a different tool name + const result2 = shouldIncludeTool(ignorePatterns, 'differentTool') + // Then it should be included (return true) + expect(result2).toBe(true) + }) +}) + +describe('Feature: MCP Proxy', () => { + it('Scenario: Proxy initialize message from client to server', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: [], + }) + + // And when client sends an initialize message + const initializeMessage = { + jsonrpc: '2.0' as const, + method: 'initialize', + id: '1', + params: { + clientInfo: { + name: 'Test Client', + version: '1.0.0', + }, + }, + } + + // Simulate client sending a message by calling the message handler directly + if (mockTransportToClient.onmessage) { + mockTransportToClient.onmessage(initializeMessage) + } + + // Then the message should be forwarded to the server + expect(mockTransportToServer.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + method: 'initialize', + id: '1', + params: expect.objectContaining({ + clientInfo: expect.objectContaining({ + name: expect.stringContaining('Test Client'), + version: '1.0.0', + }), + }), + }), + ) + }) + + it('Scenario: Proxy server response back to client', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: [], + }) + + // First simulate client sending a request (so there's a pending request) + const clientRequest = { + jsonrpc: '2.0' as const, + method: 'initialize', + id: '1', + params: { + clientInfo: { + name: 'Test Client', + version: '1.0.0', + }, + }, + } + + if (mockTransportToClient.onmessage) { + mockTransportToClient.onmessage(clientRequest) + } + + // Clear the previous call + vi.clearAllMocks() + + // Now simulate server sending a response message + const serverResponse = { + jsonrpc: '2.0' as const, + id: '1', + result: { + capabilities: { + tools: { + listChanged: true, + }, + }, + serverInfo: { + name: 'Atlassian MCP', + version: '1.0.0', + }, + }, + } + + // Simulate server sending a response by calling the message handler directly + if (mockTransportToServer.onmessage) { + mockTransportToServer.onmessage(serverResponse) + } + + // Then the response should be forwarded to the client + expect(mockTransportToClient.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: '1', + result: { + capabilities: { + tools: { + listChanged: true, + }, + }, + serverInfo: { + name: 'Atlassian MCP', + version: '1.0.0', + }, + }, + }), + ) + }) + + it('Scenario: Close server transport when client transport closes', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: [], + }) + + // And when client transport closes + if (mockTransportToClient.onclose) { + mockTransportToClient.onclose() + } + + // Then server transport should also be closed + expect(mockTransportToServer.close).toHaveBeenCalled() + }) + + it('Scenario: Close client transport when server transport closes', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: [], + }) + + // And when server transport closes + if (mockTransportToServer.onclose) { + mockTransportToServer.onclose() + } + + // Then client transport should also be closed + expect(mockTransportToClient.close).toHaveBeenCalled() + }) + + it('Scenario: Filter tools in tools/list response when ignoredTools is configured', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy with ignored tools + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: ['delete*', 'remove*'], + }) + + // First simulate client sending a tools/list request + const toolsListRequest = { + jsonrpc: '2.0' as const, + method: 'tools/list', + id: '2', + params: {}, + } + + if (mockTransportToClient.onmessage) { + mockTransportToClient.onmessage(toolsListRequest) + } + + // Clear the previous call + vi.clearAllMocks() + + // Now simulate server sending a tools/list response with various tools + const serverToolsResponse = { + jsonrpc: '2.0' as const, + id: '2', + result: { + tools: [ + { name: 'createTask', description: 'Create a new task' }, + { name: 'deleteTask', description: 'Delete a task' }, + { name: 'updateTask', description: 'Update a task' }, + { name: 'removeUser', description: 'Remove a user' }, + { name: 'listTasks', description: 'List all tasks' }, + ], + }, + } + + // Simulate server sending a response + if (mockTransportToServer.onmessage) { + mockTransportToServer.onmessage(serverToolsResponse) + } + + // Then the response should be forwarded to the client with filtered tools + expect(mockTransportToClient.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: '2', + result: { + tools: [ + { name: 'createTask', description: 'Create a new task' }, + { name: 'updateTask', description: 'Update a task' }, + { name: 'listTasks', description: 'List all tasks' }, + ], + }, + }), + ) + }) + + it('Scenario: Block tools/call for ignored tools with delete* filter', async () => { + // Given mock transports for client and server + const mockTransportToClient = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + const mockTransportToServer = { + send: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + start: vi.fn().mockResolvedValue(undefined), + onmessage: vi.fn(), + onclose: vi.fn(), + onerror: vi.fn(), + } as unknown as Transport + + // When setting up the proxy with delete* filter + mcpProxy({ + transportToClient: mockTransportToClient, + transportToServer: mockTransportToServer, + ignoredTools: ['delete*'], + }) + + // And when client tries to call a deleteTask tool + const toolsCallMessage = { + jsonrpc: '2.0' as const, + method: 'tools/call', + id: '3', + params: { + name: 'deleteTask', + arguments: { + taskId: '1', + }, + _meta: { + progressToken: 1, + }, + }, + } + + // Simulate client sending the tools/call message + if (mockTransportToClient.onmessage) { + mockTransportToClient.onmessage(toolsCallMessage) + } + + // Then the call should NOT be forwarded to the server + expect(mockTransportToServer.send).not.toHaveBeenCalled() + + // And an error response should be sent back to the client + expect(mockTransportToClient.send).toHaveBeenCalledWith( + expect.objectContaining({ + jsonrpc: '2.0', + id: '3', + error: expect.objectContaining({ + code: expect.any(Number), + message: expect.stringContaining('Tool "deleteTask" is not available'), + }), + }), + ) + }) +}) diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 6d01edf..cfafe5c 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -80,17 +80,101 @@ export function log(str: string, ...rest: unknown[]) { } } +type Message = any +const MESSAGE_BLOCKED = Symbol('MessageBlocked') +const isMessageBlocked = (value: any): value is typeof MESSAGE_BLOCKED => value === MESSAGE_BLOCKED + +export function createMessageTransformer({ + transformRequestFunction, + transformResponseFunction, +}: { + transformRequestFunction?: null | ((request: Message) => Message | typeof MESSAGE_BLOCKED) + transformResponseFunction?: null | ((request: Message, response: Message) => Message) +} = {}) { + const pendingRequests = new Map() + + const interceptRequest = (message: Message) => { + const messageId = message.id + if (!messageId) return message + pendingRequests.set(messageId, message) + return transformRequestFunction?.(message) ?? message + } + + const interceptResponse = (message: Message) => { + const messageId = message.id + if (!messageId) return message + const originalRequest = pendingRequests.get(messageId) + pendingRequests.delete(messageId) + return transformResponseFunction?.(originalRequest, message) ?? message + } + + return { + interceptRequest, + interceptResponse, + } +} + /** * Creates a bidirectional proxy between two transports * @param params The transport connections to proxy between */ -export function mcpProxy({ transportToClient, transportToServer }: { transportToClient: Transport; transportToServer: Transport }) { +export function mcpProxy({ + transportToClient, + transportToServer, + ignoredTools = [], +}: { + transportToClient: Transport + transportToServer: Transport + ignoredTools?: string[] +}) { let transportToClientClosed = false let transportToServerClosed = false + const messageTransformer = createMessageTransformer({ + transformRequestFunction: (request: Message) => { + // Block tools/call for ignored tools + if (request.method === 'tools/call' && request.params?.name) { + const toolName = request.params.name + if (!shouldIncludeTool(ignoredTools, toolName)) { + // Send error response back to client immediately + const errorResponse = { + jsonrpc: '2.0' as const, + id: request.id, + error: { + code: -32603, + message: `Tool "${toolName}" is not available`, + }, + } + transportToClient.send(errorResponse).catch(onClientError) + // Return symbol to indicate this request should not be forwarded + return MESSAGE_BLOCKED + } + } + return request + }, + transformResponseFunction: (req: Message, res: Message) => { + if (req.method === 'tools/list') { + return { + ...res, + result: { + ...res.result, + tools: res.result.tools.filter((tool: any) => shouldIncludeTool(ignoredTools, tool.name)), + }, + } + } + return res + }, + }) + transportToClient.onmessage = (_message) => { // TODO: fix types - const message = _message as any + const message = messageTransformer.interceptRequest(_message as any) + + // If interceptor returns MESSAGE_BLOCKED, don't forward the message + if (isMessageBlocked(message)) { + return + } + log('[Local→Remote]', message.method || message.id) if (DEBUG) { @@ -116,7 +200,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo transportToServer.onmessage = (_message) => { // TODO: fix types - const message = _message as any + const message = messageTransformer.interceptResponse(_message as any) log('[Remote→Local]', message.method || message.id) if (DEBUG) { @@ -617,6 +701,21 @@ export async function parseCommandLineArgs(args: string[], usage: string) { log(`Using authorize resource: ${authorizeResource}`) } + // Parse ignored tools + const ignoredTools: string[] = [] + let j = 0 + while (j < args.length) { + if (args[j] === '--ignore-tool' && j < args.length - 1) { + const toolName = args[j + 1] + ignoredTools.push(toolName) + log(`Ignoring tool: ${toolName}`) + args.splice(j, 2) + // Do not increment j, as the array has shifted + continue + } + j++ + } + if (!serverUrl) { log(usage) process.exit(1) @@ -691,6 +790,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) { staticOAuthClientMetadata, staticOAuthClientInfo, authorizeResource, + ignoredTools, } } @@ -722,3 +822,40 @@ export function setupSignalHandlers(cleanup: () => Promise) { export function getServerUrlHash(serverUrl: string): string { return crypto.createHash('md5').update(serverUrl).digest('hex') } + +/** + * Converts a glob pattern to a regular expression + * @param pattern The glob pattern (e.g., "create*", "*account") + * @returns The corresponding regular expression + */ +function patternToRegex(pattern: string): RegExp { + // Split by asterisks, escape each part, then join with .* + const parts = pattern.split('*') + const escapedParts = parts.map((part) => part.replace(/\W/g, '\\$&')) + const regexPattern = escapedParts.join('.*') + // Match the entire string from start to end, case-insensitive + return new RegExp(`^${regexPattern}$`, 'i') +} + +/** + * Determines if a tool name should be ignored based on ignore patterns + * @param ignorePatterns Array of patterns to ignore (supports wildcards with *) + * @param toolName The name of the tool to check + * @returns false if the tool should be ignored (matches a pattern), true if it should be included + */ +export function shouldIncludeTool(ignorePatterns: string[], toolName: string): boolean { + // If no patterns are provided, include all tools + if (!ignorePatterns || ignorePatterns.length === 0) { + return true + } + + // Check if the tool name matches any ignore pattern + for (const pattern of ignorePatterns) { + const regex = patternToRegex(pattern) + if (regex.test(toolName)) { + return false // Tool matches an ignore pattern, so exclude it + } + } + + return true // Tool doesn't match any ignore pattern, so include it +} diff --git a/src/proxy.ts b/src/proxy.ts index 6627bf8..d23ed95 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -36,6 +36,7 @@ async function runProxy( staticOAuthClientMetadata: StaticOAuthClientMetadata, staticOAuthClientInfo: StaticOAuthClientInformationFull, authorizeResource: string, + ignoredTools: string[], ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -92,6 +93,7 @@ async function runProxy( mcpProxy({ transportToClient: localTransport, transportToServer: remoteTransport, + ignoredTools, }) // Start the local STDIO server @@ -155,6 +157,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts { return runProxy( serverUrl, @@ -165,6 +168,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts