Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export const groupOptionsSchema = z.object({
{ message: "Invalid regular expression pattern" },
),
description: z.string().optional(),
allowedMcpServers: z.array(z.string()).optional(),
})

export type GroupOptions = z.infer<typeof groupOptionsSchema>
Expand Down
61 changes: 61 additions & 0 deletions src/core/config/__tests__/ModeConfig.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,67 @@ describe("CustomModeSchema", () => {
})
})

describe("allowedMcpServers", () => {
it("validates a mode with MCP server restrictions", () => {
const modeWithMcpRestrictions = {
slug: "weather-mode",
name: "Weather Mode",
roleDefinition: "Weather analysis mode",
groups: ["read", ["mcp", { allowedMcpServers: ["weather-server", "climate-server"] }]],
}

const modeWithDescription = {
slug: "data-mode",
name: "Data Mode",
roleDefinition: "Data analysis mode",
groups: [
"read",
["mcp", { allowedMcpServers: ["database-server"], description: "Database server only" }],
],
}

expect(() => modeConfigSchema.parse(modeWithMcpRestrictions)).not.toThrow()
expect(() => modeConfigSchema.parse(modeWithDescription)).not.toThrow()
})

it("accepts empty allowedMcpServers array", () => {
const modeWithEmptyMcpList = {
slug: "no-mcp-mode",
name: "No MCP Mode",
roleDefinition: "Mode without MCP access",
groups: ["read", ["mcp", { allowedMcpServers: [] }]],
}

expect(() => modeConfigSchema.parse(modeWithEmptyMcpList)).not.toThrow()
})

it("validates that allowedMcpServers contains only strings", () => {
const modeWithInvalidMcpList = {
slug: "invalid-mode",
name: "Invalid Mode",
roleDefinition: "Invalid mode",
groups: ["read", ["mcp", { allowedMcpServers: ["valid-server", 123, "another-server"] }]],
}

expect(() => modeConfigSchema.parse(modeWithInvalidMcpList)).toThrow()
})

it("allows combining fileRegex and allowedMcpServers in different groups", () => {
const modeWithBothRestrictions = {
slug: "restricted-mode",
name: "Restricted Mode",
roleDefinition: "Mode with multiple restrictions",
groups: [
"read",
["edit", { fileRegex: "\\.md$", description: "Markdown files only" }],
["mcp", { allowedMcpServers: ["weather-server"], description: "Weather server only" }],
],
}

expect(() => modeConfigSchema.parse(modeWithBothRestrictions)).not.toThrow()
})
})

const validBaseMode = {
slug: "123e4567-e89b-12d3-a456-426614174000",
name: "Test Mode",
Expand Down
167 changes: 166 additions & 1 deletion src/shared/__tests__/modes.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ vi.mock("../../core/prompts/sections/custom-instructions", () => ({
addCustomInstructions: vi.fn().mockResolvedValue("Combined instructions"),
}))

import { isToolAllowedForMode, FileRestrictionError, getFullModeDetails, modes, getModeSelection } from "../modes"
import {
isToolAllowedForMode,
FileRestrictionError,
McpServerRestrictionError,
getFullModeDetails,
modes,
getModeSelection,
} from "../modes"
import { addCustomInstructions } from "../../core/prompts/sections/custom-instructions"

describe("isToolAllowedForMode", () => {
Expand Down Expand Up @@ -247,6 +254,164 @@ describe("isToolAllowedForMode", () => {
})
})

describe("MCP server restrictions", () => {
const customModesWithMcpRestrictions: ModeConfig[] = [
{
slug: "restricted-mcp-mode",
name: "Restricted MCP Mode",
roleDefinition: "You are a mode with MCP restrictions",
groups: [
"read",
[
"mcp",
{
allowedMcpServers: ["weather-server", "file-server"],
description: "Weather and file servers only",
},
],
],
},
{
slug: "unrestricted-mcp-mode",
name: "Unrestricted MCP Mode",
roleDefinition: "You are a mode without MCP restrictions",
groups: ["read", "mcp"],
},
]

it("allows using permitted MCP servers", () => {
// Should allow weather-server
expect(
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "weather-server",
tool_name: "get_forecast",
arguments: '{"city": "San Francisco"}',
}),
).toBe(true)

// Should allow file-server
expect(
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "file-server",
tool_name: "read_file",
arguments: '{"path": "test.txt"}',
}),
).toBe(true)

// Should allow access_mcp_resource with permitted servers
expect(
isToolAllowedForMode(
"access_mcp_resource",
"restricted-mcp-mode",
customModesWithMcpRestrictions,
undefined,
{
server_name: "weather-server",
uri: "weather://current",
},
),
).toBe(true)
})

it("rejects using non-permitted MCP servers", () => {
// Should reject database-server
expect(() =>
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "database-server",
tool_name: "query",
arguments: '{"sql": "SELECT * FROM users"}',
}),
).toThrow(McpServerRestrictionError)
expect(() =>
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "database-server",
tool_name: "query",
arguments: '{"sql": "SELECT * FROM users"}',
}),
).toThrow(/weather-server, file-server/)
expect(() =>
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "database-server",
tool_name: "query",
arguments: '{"sql": "SELECT * FROM users"}',
}),
).toThrow(/Weather and file servers only/)

// Should reject access_mcp_resource with non-permitted servers
expect(() =>
isToolAllowedForMode(
"access_mcp_resource",
"restricted-mcp-mode",
customModesWithMcpRestrictions,
undefined,
{
server_name: "database-server",
uri: "database://users",
},
),
).toThrow(McpServerRestrictionError)
})

it("allows unrestricted modes to use any MCP server", () => {
// Should allow any server for unrestricted mode
expect(
isToolAllowedForMode(
"use_mcp_tool",
"unrestricted-mcp-mode",
customModesWithMcpRestrictions,
undefined,
{
server_name: "any-server",
tool_name: "any_tool",
arguments: "{}",
},
),
).toBe(true)

expect(
isToolAllowedForMode(
"access_mcp_resource",
"unrestricted-mcp-mode",
customModesWithMcpRestrictions,
undefined,
{
server_name: "any-server",
uri: "any://resource",
},
),
).toBe(true)
})

it("allows MCP tools without server_name parameter for partial requests", () => {
// Should allow partial MCP tool requests without server_name (streaming scenarios)
expect(
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
tool_name: "get_forecast",
}),
).toBe(true)

expect(
isToolAllowedForMode(
"use_mcp_tool",
"restricted-mcp-mode",
customModesWithMcpRestrictions,
undefined,
{},
),
).toBe(true)
})

it("uses description in MCP server restriction error", () => {
expect(() =>
isToolAllowedForMode("use_mcp_tool", "restricted-mcp-mode", customModesWithMcpRestrictions, undefined, {
server_name: "forbidden-server",
tool_name: "test",
arguments: "{}",
}),
).toThrow(/Weather and file servers only/)
})
})

it("handles non-existent modes", () => {
expect(isToolAllowedForMode("write_to_file", "non-existent", customModes)).toBe(false)
})
Expand Down
23 changes: 23 additions & 0 deletions src/shared/modes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ export class FileRestrictionError extends Error {
}
}

// Custom error class for MCP server restrictions
export class McpServerRestrictionError extends Error {
constructor(mode: string, allowedServers: string[], description: string | undefined, serverName: string) {
super(
`This mode (${mode}) can only use MCP servers: ${allowedServers.join(", ")}${description ? ` (${description})` : ""}. Got: ${serverName}`,
)
this.name = "McpServerRestrictionError"
}
}

export function isToolAllowedForMode(
tool: string,
modeSlug: string,
Expand Down Expand Up @@ -267,6 +277,19 @@ export function isToolAllowedForMode(
}
}

// For the mcp group, check allowed MCP servers if specified
if (groupName === "mcp" && options.allowedMcpServers) {
const serverName = toolParams?.server_name
if (serverName && !options.allowedMcpServers.includes(serverName)) {
throw new McpServerRestrictionError(
mode.name,
options.allowedMcpServers,
options.description,
serverName,
)
}
}

return true
}

Expand Down