Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/types/src/mode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ export const groupOptionsSchema = z.object({
{ message: "Invalid regular expression pattern" },
),
description: z.string().optional(),
mcp: z
.object({
included: z.array(z.string()),
description: z.string().optional(),
})
.optional(),
})

export type GroupOptions = z.infer<typeof groupOptionsSchema>
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 30 additions & 1 deletion src/core/prompts/__tests__/add-custom-instructions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,36 @@ const mockContext = {
// Instead of extending McpHub, create a mock that implements just what we need
const createMockMcpHub = (withServers: boolean = false): McpHub =>
({
getServers: () => (withServers ? [{ name: "test-server", disabled: false }] : []),
getServers: () =>
withServers
? [
{
name: "test-server",
disabled: false,
status: "connected",
config: JSON.stringify({ command: "test-command" }),
tools: [],
resourceTemplates: [],
resources: [],
instructions: undefined,
},
]
: [],
getAllServers: () =>
withServers
? [
{
name: "test-server",
disabled: false,
status: "connected",
config: JSON.stringify({ command: "test-command" }),
tools: [],
resourceTemplates: [],
resources: [],
instructions: undefined,
},
]
: [],
getMcpServersPath: async () => "/mock/mcp/path",
getMcpSettingsFilePath: async () => "/mock/settings/path",
dispose: async () => {},
Expand Down
31 changes: 30 additions & 1 deletion src/core/prompts/__tests__/system-prompt.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,36 @@ const mockContext = {
// Instead of extending McpHub, create a mock that implements just what we need
const createMockMcpHub = (withServers: boolean = false): McpHub =>
({
getServers: () => (withServers ? [{ name: "test-server", disabled: false }] : []),
getServers: () =>
withServers
? [
{
name: "test-server",
disabled: false,
status: "connected",
config: JSON.stringify({ command: "test-command" }),
tools: [],
resourceTemplates: [],
resources: [],
instructions: undefined,
},
]
: [],
getAllServers: () =>
withServers
? [
{
name: "test-server",
disabled: false,
status: "connected",
config: JSON.stringify({ command: "test-command" }),
tools: [],
resourceTemplates: [],
resources: [],
instructions: undefined,
},
]
: [],
getMcpServersPath: async () => "/mock/mcp/path",
getMcpSettingsFilePath: async () => "/mock/settings/path",
dispose: async () => {},
Expand Down
152 changes: 116 additions & 36 deletions src/core/prompts/sections/mcp-servers.ts
Original file line number Diff line number Diff line change
@@ -1,53 +1,133 @@
import { DiffStrategy } from "../../../shared/tools"
import { McpHub } from "../../../services/mcp/McpHub"
import { GroupEntry, ModeConfig } from "@roo-code/types"
import { getGroupName } from "../../../shared/modes"
import { McpServer } from "../../../shared/mcp"

let lastMcpHub: McpHub | undefined
let lastMcpIncludedList: string[] | undefined
let lastFilteredServers: McpServer[] = []

function memoizeFilteredServers(mcpHub: McpHub, mcpIncludedList?: string[]): McpServer[] {
const mcpHubChanged = mcpHub !== lastMcpHub
const listChanged = !areArraysEqual(mcpIncludedList, lastMcpIncludedList)

if (!mcpHubChanged && !listChanged) {
return lastFilteredServers
}

lastMcpHub = mcpHub
lastMcpIncludedList = mcpIncludedList

lastFilteredServers = (
mcpIncludedList && mcpIncludedList.length > 0 ? mcpHub.getAllServers() : mcpHub.getServers()
).filter((server) => {
if (mcpIncludedList && mcpIncludedList.length > 0) {
return mcpIncludedList.includes(server.name) && server.status === "connected"
}
return server.status === "connected"
})

return lastFilteredServers
}
function areArraysEqual(arr1?: string[], arr2?: string[]): boolean {
if (!arr1 && !arr2) return true
if (!arr1 || !arr2) return false
if (arr1.length !== arr2.length) return false

return arr1.every((item, index) => item === arr2[index])
}

export async function getMcpServersSection(
mcpHub?: McpHub,
diffStrategy?: DiffStrategy,
enableMcpServerCreation?: boolean,
currentMode?: ModeConfig,
): Promise<string> {
if (!mcpHub) {
return ""
}

const connectedServers =
mcpHub.getServers().length > 0
? `${mcpHub
.getServers()
.filter((server) => server.status === "connected")
.map((server) => {
const tools = server.tools
// Get MCP configuration for current mode
let mcpIncludedList: string[] | undefined

if (currentMode) {
// Find MCP group configuration
const mcpGroup = currentMode.groups.find((group: GroupEntry) => {
if (Array.isArray(group) && group.length === 2 && group[0] === "mcp") {
return true
}
return getGroupName(group) === "mcp"
})

// If MCP group configuration is found, get mcpIncludedList from mcp.included
if (mcpGroup && Array.isArray(mcpGroup) && mcpGroup.length === 2) {
const options = mcpGroup[1] as { mcp?: { included?: unknown[] } }
mcpIncludedList = Array.isArray(options.mcp?.included)
? options.mcp.included.filter((item: unknown): item is string => typeof item === "string")
: undefined
}
}

const filteredServers = memoizeFilteredServers(mcpHub, mcpIncludedList)

let connectedServers: string

if (filteredServers.length > 0) {
connectedServers = `${filteredServers
.map((server) => {
const tools = server.tools
?.filter((tool) => tool.enabledForPrompt !== false)
?.map((tool) => {
const schemaStr = tool.inputSchema
? ` Input Schema:
${JSON.stringify(tool.inputSchema, null, 2).split("\n").join("\n ")}`
: ""

return `- ${tool.name}: ${tool.description}\n${schemaStr}`
})
.join("\n\n")

const templates = server.resourceTemplates
?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`)
.join("\n")

const resources = server.resources
?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`)
.join("\n")

const config = JSON.parse(server.config)

return (
`## ${server.name}${config.command ? ` (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` : ""}` +
(server.instructions ? `\n\n### Instructions\n${server.instructions}` : "") +
(tools ? `\n\n### Available Tools\n${tools}` : "") +
(templates ? `\n\n### Resource Templates\n${templates}` : "") +
(resources ? `\n\n### Direct Resources\n${resources}` : "")
)
?.map((tool) => {
const schemaStr = tool.inputSchema
? ` Input Schema:
${JSON.stringify(tool.inputSchema, null, 2).split("\n").join("\n ")}`
: ""

return `- ${tool.name}: ${tool.description}\n${schemaStr}`
})
.join("\n\n")}`
: "(No MCP servers currently connected)"
.join("\n\n")

const templates = server.resourceTemplates
?.map((template) => `- ${template.uriTemplate} (${template.name}): ${template.description}`)
.join("\n")

const resources = server.resources
?.map((resource) => `- ${resource.uri} (${resource.name}): ${resource.description}`)
.join("\n")

const config = JSON.parse(server.config)

return (
`## ${server.name}${config.command ? ` (\`${config.command}${config.args && Array.isArray(config.args) ? ` ${config.args.join(" ")}` : ""}\`)` : ""}` +
(server.instructions ? `\n\n### Instructions\n${server.instructions}` : "") +
(tools ? `\n\n### Available Tools\n${tools}` : "") +
(templates ? `\n\n### Resource Templates\n${templates}` : "") +
(resources ? `\n\n### Direct Resources\n${resources}` : "")
)
})
.join("\n\n")}`
} else if (mcpIncludedList && mcpIncludedList.length > 0) {
const allServers = mcpHub.getAllServers()
const disconnectedServers = mcpIncludedList
.map((name) => {
const server = allServers.find((s) => s.name === name)
if (server && server.status !== "connected") {
return `- ${server.name} (${server.status})`
}
if (!server) {
return `- ${name} (not found)`
}
return null
})
.filter(Boolean)
.join("\n")
connectedServers = `(Configured MCP servers are not currently connected)${
disconnectedServers ? `\n\nConfigured but disconnected servers:\n${disconnectedServers}` : ""
}`
} else {
connectedServers = "(No MCP servers currently connected)"
}

const baseSection = `MCP SERVERS

Expand Down
2 changes: 1 addition & 1 deletion src/core/prompts/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async function generatePrompt(
const [modesSection, mcpServersSection] = await Promise.all([
getModesSection(context),
shouldIncludeMcp
? getMcpServersSection(mcpHub, effectiveDiffStrategy, enableMcpServerCreation)
? getMcpServersSection(mcpHub, effectiveDiffStrategy, enableMcpServerCreation, modeConfig)
: Promise.resolve(""),
])

Expand Down
Loading
Loading