diff --git a/packages/types/src/mode.ts b/packages/types/src/mode.ts index dfe95f8d7e..db9836828c 100644 --- a/packages/types/src/mode.ts +++ b/packages/types/src/mode.ts @@ -26,6 +26,7 @@ export const groupOptionsSchema = z.object({ { message: "Invalid regular expression pattern" }, ), description: z.string().optional(), + allowedMcpServers: z.array(z.string()).optional(), }) export type GroupOptions = z.infer diff --git a/src/core/config/__tests__/ModeConfig.spec.ts b/src/core/config/__tests__/ModeConfig.spec.ts index dbdd1a0f03..c8d0217c1c 100644 --- a/src/core/config/__tests__/ModeConfig.spec.ts +++ b/src/core/config/__tests__/ModeConfig.spec.ts @@ -176,6 +176,67 @@ describe("CustomModeSchema", () => { }) }) + describe("allowedMcpServers", () => { + it("validates a mode with MCP server restrictions", () => { + const modeWithMcpRestrictions = { + slug: "weather-mode", + name: "Weather Mode", + roleDefinition: "Weather analysis mode", + groups: ["read", ["mcp", { allowedMcpServers: ["weather-server", "climate-server"] }]], + } + + const modeWithDescription = { + slug: "data-mode", + name: "Data Mode", + roleDefinition: "Data analysis mode", + groups: [ + "read", + ["mcp", { allowedMcpServers: ["database-server"], description: "Database server only" }], + ], + } + + expect(() => modeConfigSchema.parse(modeWithMcpRestrictions)).not.toThrow() + expect(() => modeConfigSchema.parse(modeWithDescription)).not.toThrow() + }) + + it("accepts empty allowedMcpServers array", () => { + const modeWithEmptyMcpList = { + slug: "no-mcp-mode", + name: "No MCP Mode", + roleDefinition: "Mode without MCP access", + groups: ["read", ["mcp", { allowedMcpServers: [] }]], + } + + expect(() => modeConfigSchema.parse(modeWithEmptyMcpList)).not.toThrow() + }) + + it("validates that allowedMcpServers contains only strings", () => { + const modeWithInvalidMcpList = { + slug: "invalid-mode", + name: "Invalid Mode", + roleDefinition: "Invalid mode", + groups: ["read", ["mcp", { allowedMcpServers: ["valid-server", 123, "another-server"] }]], + } + + expect(() => modeConfigSchema.parse(modeWithInvalidMcpList)).toThrow() + }) + + it("allows combining fileRegex and allowedMcpServers in different groups", () => { + const modeWithBothRestrictions = { + slug: "restricted-mode", + name: "Restricted Mode", + roleDefinition: "Mode with multiple restrictions", + groups: [ + "read", + ["edit", { fileRegex: "\\.md$", description: "Markdown files only" }], + ["mcp", { allowedMcpServers: ["weather-server"], description: "Weather server only" }], + ], + } + + expect(() => modeConfigSchema.parse(modeWithBothRestrictions)).not.toThrow() + }) + }) + const validBaseMode = { slug: "123e4567-e89b-12d3-a456-426614174000", name: "Test Mode", diff --git a/src/shared/__tests__/modes.spec.ts b/src/shared/__tests__/modes.spec.ts index 8ca7eec150..a7bf829dcf 100644 --- a/src/shared/__tests__/modes.spec.ts +++ b/src/shared/__tests__/modes.spec.ts @@ -9,7 +9,14 @@ vi.mock("../../core/prompts/sections/custom-instructions", () => ({ addCustomInstructions: vi.fn().mockResolvedValue("Combined instructions"), })) -import { isToolAllowedForMode, FileRestrictionError, getFullModeDetails, modes, getModeSelection } from "../modes" +import { + isToolAllowedForMode, + FileRestrictionError, + McpServerRestrictionError, + getFullModeDetails, + modes, + getModeSelection, +} from "../modes" import { addCustomInstructions } from "../../core/prompts/sections/custom-instructions" describe("isToolAllowedForMode", () => { @@ -247,6 +254,164 @@ describe("isToolAllowedForMode", () => { }) }) + describe("MCP server restrictions", () => { + const customModesWithMcpRestrictions: ModeConfig[] = [ + { + slug: "restricted-mcp-mode", + name: "Restricted MCP Mode", + roleDefinition: "You are a mode with MCP restrictions", + groups: [ + "read", + [ + "mcp", + { + allowedMcpServers: ["weather-server", "file-server"], + description: "Weather and file servers only", + }, + ], + ], + }, + { + slug: "unrestricted-mcp-mode", + name: "Unrestricted MCP Mode", + roleDefinition: "You are a mode without MCP restrictions", + groups: ["read", "mcp"], + }, + ] + + it("allows using permitted MCP servers", () => { + // Should allow weather-server + expect( + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "weather-server", + tool_name: "get_forecast", + arguments: '{"city": "San Francisco"}', + }), + ).toBe(true) + + // Should allow file-server + expect( + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "file-server", + tool_name: "read_file", + arguments: '{"path": "test.txt"}', + }), + ).toBe(true) + + // Should allow access_mcp_resource with permitted servers + expect( + isToolAllowedForMode( + "access_mcp_resource", + "restricted-mcp-mode", + customModesWithMcpRestrictions, + undefined, + { + server_name: "weather-server", + uri: "weather://current", + }, + ), + ).toBe(true) + }) + + it("rejects using non-permitted MCP servers", () => { + // Should reject database-server + expect(() => + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "database-server", + tool_name: "query", + arguments: '{"sql": "SELECT * FROM users"}', + }), + ).toThrow(McpServerRestrictionError) + expect(() => + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "database-server", + tool_name: "query", + arguments: '{"sql": "SELECT * FROM users"}', + }), + ).toThrow(/weather-server, file-server/) + expect(() => + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "database-server", + tool_name: "query", + arguments: '{"sql": "SELECT * FROM users"}', + }), + ).toThrow(/Weather and file servers only/) + + // Should reject access_mcp_resource with non-permitted servers + expect(() => + isToolAllowedForMode( + "access_mcp_resource", + "restricted-mcp-mode", + customModesWithMcpRestrictions, + undefined, + { + server_name: "database-server", + uri: "database://users", + }, + ), + ).toThrow(McpServerRestrictionError) + }) + + it("allows unrestricted modes to use any MCP server", () => { + // Should allow any server for unrestricted mode + expect( + isToolAllowedForMode( + "use_mcp_tool", + "unrestricted-mcp-mode", + customModesWithMcpRestrictions, + undefined, + { + server_name: "any-server", + tool_name: "any_tool", + arguments: "{}", + }, + ), + ).toBe(true) + + expect( + isToolAllowedForMode( + "access_mcp_resource", + "unrestricted-mcp-mode", + customModesWithMcpRestrictions, + undefined, + { + server_name: "any-server", + uri: "any://resource", + }, + ), + ).toBe(true) + }) + + it("allows MCP tools without server_name parameter for partial requests", () => { + // Should allow partial MCP tool requests without server_name (streaming scenarios) + expect( + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + tool_name: "get_forecast", + }), + ).toBe(true) + + expect( + isToolAllowedForMode( + "use_mcp_tool", + "restricted-mcp-mode", + customModesWithMcpRestrictions, + undefined, + {}, + ), + ).toBe(true) + }) + + it("uses description in MCP server restriction error", () => { + expect(() => + isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, { + server_name: "forbidden-server", + tool_name: "test", + arguments: "{}", + }), + ).toThrow(/Weather and file servers only/) + }) + }) + it("handles non-existent modes", () => { expect(isToolAllowedForMode("write_to_file", "non-existent", customModes)).toBe(false) }) diff --git a/src/shared/modes.ts b/src/shared/modes.ts index 56d41f3c73..baa6ecf057 100644 --- a/src/shared/modes.ts +++ b/src/shared/modes.ts @@ -205,6 +205,16 @@ export class FileRestrictionError extends Error { } } +// Custom error class for MCP server restrictions +export class McpServerRestrictionError extends Error { + constructor(mode: string, allowedServers: string[], description: string | undefined, serverName: string) { + super( + `This mode (${mode}) can only use MCP servers: ${allowedServers.join(", ")}${description ? ` (${description})` : ""}. Got: ${serverName}`, + ) + this.name = "McpServerRestrictionError" + } +} + export function isToolAllowedForMode( tool: string, modeSlug: string, @@ -267,6 +277,19 @@ export function isToolAllowedForMode( } } + // For the mcp group, check allowed MCP servers if specified + if (groupName === "mcp" && options.allowedMcpServers) { + const serverName = toolParams?.server_name + if (serverName && !options.allowedMcpServers.includes(serverName)) { + throw new McpServerRestrictionError( + mode.name, + options.allowedMcpServers, + options.description, + serverName, + ) + } + } + return true }