From e99023a9c73fde2c286267e1494db118a33ad301 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 15 Sep 2025 21:59:42 +0000 Subject: [PATCH 1/2] feat: add MCP server prompts support - Add prompt types to shared/mcp.ts (McpPrompt, McpPromptArgument, etc.) - Update McpHub to fetch prompts from MCP servers via prompts/list - Add getPrompt method to McpHub for executing prompts - Create mcp-prompts.ts service to integrate MCP prompts with slash commands - Update command system to support MCP prompts as slash commands (mcp..) - Add support for prompt arguments/parameters - Update tests to support new function signatures Implements #8004 --- src/__tests__/command-integration.spec.ts | 10 +- src/__tests__/commands.spec.ts | 16 +- src/core/mentions/index.ts | 3 +- src/core/tools/runSlashCommandTool.ts | 52 +++++- src/core/webview/webviewMessageHandler.ts | 7 +- .../__tests__/frontmatter-commands.spec.ts | 26 +-- src/services/command/commands.ts | 36 ++-- src/services/command/mcp-prompts.ts | 161 ++++++++++++++++++ src/services/mcp/McpHub.ts | 51 +++++- src/shared/mcp.ts | 35 ++++ 10 files changed, 353 insertions(+), 44 deletions(-) create mode 100644 src/services/command/mcp-prompts.ts diff --git a/src/__tests__/command-integration.spec.ts b/src/__tests__/command-integration.spec.ts index 59427415f7..f520a1a304 100644 --- a/src/__tests__/command-integration.spec.ts +++ b/src/__tests__/command-integration.spec.ts @@ -6,7 +6,7 @@ describe("Command Integration Tests", () => { const testWorkspaceDir = path.join(__dirname, "../../") it("should discover command files in .roo/commands/", async () => { - const commands = await getCommands(testWorkspaceDir) + const commands = await getCommands(testWorkspaceDir, undefined) // Should be able to discover commands (may be empty in test environment) expect(Array.isArray(commands)).toBe(true) @@ -22,7 +22,7 @@ describe("Command Integration Tests", () => { }) it("should return command names correctly", async () => { - const commandNames = await getCommandNames(testWorkspaceDir) + const commandNames = await getCommandNames(testWorkspaceDir, undefined) // Should return an array (may be empty in test environment) expect(Array.isArray(commandNames)).toBe(true) @@ -35,11 +35,11 @@ describe("Command Integration Tests", () => { }) it("should load command content if commands exist", async () => { - const commands = await getCommands(testWorkspaceDir) + const commands = await getCommands(testWorkspaceDir, undefined) if (commands.length > 0) { const firstCommand = commands[0] - const loadedCommand = await getCommand(testWorkspaceDir, firstCommand.name) + const loadedCommand = await getCommand(testWorkspaceDir, firstCommand.name, undefined) expect(loadedCommand).toBeDefined() expect(loadedCommand?.name).toBe(firstCommand.name) @@ -50,7 +50,7 @@ describe("Command Integration Tests", () => { }) it("should handle non-existent commands gracefully", async () => { - const nonExistentCommand = await getCommand(testWorkspaceDir, "non-existent-command") + const nonExistentCommand = await getCommand(testWorkspaceDir, "non-existent-command", undefined) expect(nonExistentCommand).toBeUndefined() }) }) diff --git a/src/__tests__/commands.spec.ts b/src/__tests__/commands.spec.ts index e3d9e81c23..f087ab337a 100644 --- a/src/__tests__/commands.spec.ts +++ b/src/__tests__/commands.spec.ts @@ -40,21 +40,21 @@ describe("Command Utilities", () => { describe("getCommands", () => { it("should return empty array when no command directories exist", async () => { // This will fail to find directories but should return empty array gracefully - const commands = await getCommands(testCwd) + const commands = await getCommands(testCwd, undefined) expect(Array.isArray(commands)).toBe(true) }) }) describe("getCommandNames", () => { it("should return empty array when no commands exist", async () => { - const names = await getCommandNames(testCwd) + const names = await getCommandNames(testCwd, undefined) expect(Array.isArray(names)).toBe(true) }) }) describe("getCommand", () => { it("should return undefined for non-existent command", async () => { - const result = await getCommand(testCwd, "non-existent") + const result = await getCommand(testCwd, "non-existent", undefined) expect(result).toBeUndefined() }) }) @@ -78,8 +78,8 @@ describe("Command Utilities", () => { describe("command loading behavior", () => { it("should handle multiple calls to getCommands", async () => { - const commands1 = await getCommands(testCwd) - const commands2 = await getCommands(testCwd) + const commands1 = await getCommands(testCwd, undefined) + const commands2 = await getCommands(testCwd, undefined) expect(Array.isArray(commands1)).toBe(true) expect(Array.isArray(commands2)).toBe(true) }) @@ -88,9 +88,9 @@ describe("Command Utilities", () => { describe("error handling", () => { it("should handle invalid command names gracefully", async () => { // These should not throw errors - expect(await getCommand(testCwd, "")).toBeUndefined() - expect(await getCommand(testCwd, " ")).toBeUndefined() - expect(await getCommand(testCwd, "non/existent/path")).toBeUndefined() + expect(await getCommand(testCwd, "", undefined)).toBeUndefined() + expect(await getCommand(testCwd, " ", undefined)).toBeUndefined() + expect(await getCommand(testCwd, "non/existent/path", undefined)).toBeUndefined() }) }) }) diff --git a/src/core/mentions/index.ts b/src/core/mentions/index.ts index f038b5b783..1dd0e977c3 100644 --- a/src/core/mentions/index.ts +++ b/src/core/mentions/index.ts @@ -92,7 +92,8 @@ export async function parseMentions( const commandExistenceChecks = await Promise.all( Array.from(uniqueCommandNames).map(async (commandName) => { try { - const command = await getCommand(cwd, commandName) + // TODO: Pass McpHub instance when available for MCP prompt support + const command = await getCommand(cwd, commandName, undefined) return { commandName, command } } catch (error) { // If there's an error checking command existence, treat it as non-existent diff --git a/src/core/tools/runSlashCommandTool.ts b/src/core/tools/runSlashCommandTool.ts index 06ceb5f19c..ee9f961209 100644 --- a/src/core/tools/runSlashCommandTool.ts +++ b/src/core/tools/runSlashCommandTool.ts @@ -3,6 +3,7 @@ import { ToolUse, AskApproval, HandleError, PushToolResult, RemoveClosingTag } f import { formatResponse } from "../prompts/responses" import { getCommand, getCommandNames } from "../../services/command/commands" import { EXPERIMENT_IDS, experiments } from "../../shared/experiments" +import { McpServerManager } from "../../services/mcp/McpServerManager" export async function runSlashCommandTool( task: Task, @@ -49,12 +50,20 @@ export async function runSlashCommandTool( task.consecutiveMistakeCount = 0 - // Get the command from the commands service - const command = await getCommand(task.cwd, commandName) + // Get the command from the commands service (pass McpHub for MCP prompt support) + let mcpHub = undefined + if (provider) { + try { + mcpHub = await McpServerManager.getInstance(provider.context, provider) + } catch (error) { + console.error("Failed to get MCP hub:", error) + } + } + const command = await getCommand(task.cwd, commandName, mcpHub) if (!command) { // Get available commands for error message - const availableCommands = await getCommandNames(task.cwd) + const availableCommands = await getCommandNames(task.cwd, mcpHub) task.recordToolError("run_slash_command") pushToolResult( formatResponse.toolError( @@ -64,6 +73,41 @@ export async function runSlashCommandTool( return } + // Handle MCP prompt commands differently + let commandContent = command.content + + if (command.source === "mcp" && command.name.startsWith("mcp.") && mcpHub) { + const parts = command.name.split(".") + if (parts.length >= 3) { + const serverName = parts[1] + const promptName = parts.slice(2).join(".") + + try { + const { executeMcpPrompt, parsePromptArguments } = await import( + "../../services/command/mcp-prompts" + ) + + // Parse arguments if provided + let promptArgs: Record = {} + if (args) { + const servers = mcpHub.getAllServers() + const server = servers.find((s) => s.name === serverName) + const prompt = server?.prompts?.find((p) => p.name === promptName) + + if (prompt) { + promptArgs = parsePromptArguments(prompt, args) + } + } + + // Execute the MCP prompt to get the actual content + commandContent = await executeMcpPrompt(mcpHub, serverName, promptName, promptArgs) + } catch (error) { + console.error(`Failed to execute MCP prompt ${command.name}:`, error) + commandContent = `Error executing MCP prompt: ${error instanceof Error ? error.message : String(error)}` + } + } + } + const toolMessage = JSON.stringify({ tool: "runSlashCommand", command: commandName, @@ -94,7 +138,7 @@ export async function runSlashCommandTool( } result += `\nSource: ${command.source}` - result += `\n\n--- Command Content ---\n\n${command.content}` + result += `\n\n--- Command Content ---\n\n${commandContent}` // Return the command content as the tool result pushToolResult(result) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index abdfae29fa..bd13aa54e5 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -2817,7 +2817,10 @@ export const webviewMessageHandler = async ( try { if (message.text) { const { getCommand } = await import("../../services/command/commands") - const command = await getCommand(getCurrentCwd(), message.text) + const { executeMcpPrompt, parsePromptArguments } = await import( + "../../services/command/mcp-prompts" + ) + const command = await getCommand(getCurrentCwd(), message.text, provider.mcpHub) if (command && command.filePath) { openFile(command.filePath) @@ -2954,7 +2957,7 @@ export const webviewMessageHandler = async ( // Refresh commands list const { getCommands } = await import("../../services/command/commands") - const commands = await getCommands(getCurrentCwd() || "") + const commands = await getCommands(getCurrentCwd() || "", provider.mcpHub) const commandList = commands.map((command) => ({ name: command.name, source: command.source, diff --git a/src/services/command/__tests__/frontmatter-commands.spec.ts b/src/services/command/__tests__/frontmatter-commands.spec.ts index 40acc8ae84..1aaf9d571c 100644 --- a/src/services/command/__tests__/frontmatter-commands.spec.ts +++ b/src/services/command/__tests__/frontmatter-commands.spec.ts @@ -40,7 +40,7 @@ npm run build mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result).toEqual({ name: "setup", @@ -64,7 +64,7 @@ npm run build mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result).toEqual({ name: "setup", @@ -89,7 +89,7 @@ Command content here.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result?.description).toBeUndefined() }) @@ -107,7 +107,7 @@ Command content here.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result).toEqual({ name: "setup", @@ -142,7 +142,7 @@ Global setup instructions.` .mockResolvedValueOnce(projectCommandContent) // First call for project .mockResolvedValueOnce(globalCommandContent) // Second call for global (shouldn't be used) - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result).toEqual({ name: "setup", @@ -169,7 +169,7 @@ Global setup instructions.` .mockRejectedValueOnce(new Error("File not found")) // Project command doesn't exist .mockResolvedValueOnce(globalCommandContent) // Global command exists - const result = await getCommand("/test/cwd", "setup") + const result = await getCommand("/test/cwd", "setup", undefined) expect(result).toEqual({ name: "setup", @@ -196,7 +196,7 @@ Create a new release.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "release") + const result = await getCommand("/test/cwd", "release", undefined) expect(result).toEqual({ name: "release", @@ -222,7 +222,7 @@ Deploy the application.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "deploy") + const result = await getCommand("/test/cwd", "deploy", undefined) expect(result).toEqual({ name: "deploy", @@ -247,7 +247,7 @@ Test content.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "test") + const result = await getCommand("/test/cwd", "test", undefined) expect(result?.argumentHint).toBeUndefined() }) @@ -265,7 +265,7 @@ Test content.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "test") + const result = await getCommand("/test/cwd", "test", undefined) expect(result?.argumentHint).toBeUndefined() }) @@ -283,7 +283,7 @@ Test content.` mockFs.stat = vi.fn().mockResolvedValue({ isDirectory: () => true }) mockFs.readFile = vi.fn().mockResolvedValue(commandContent) - const result = await getCommand("/test/cwd", "test") + const result = await getCommand("/test/cwd", "test", undefined) expect(result?.argumentHint).toBeUndefined() }) @@ -324,7 +324,7 @@ Build instructions without frontmatter.` .mockResolvedValueOnce(deployContent) .mockResolvedValueOnce(buildContent) - const result = await getCommands("/test/cwd") + const result = await getCommands("/test/cwd", undefined) expect(result).toHaveLength(3) expect(result).toEqual( @@ -374,7 +374,7 @@ Deploy the app.` ]) mockFs.readFile = vi.fn().mockResolvedValueOnce(releaseContent).mockResolvedValueOnce(deployContent) - const result = await getCommands("/test/cwd") + const result = await getCommands("/test/cwd", undefined) expect(result).toHaveLength(2) expect(result).toEqual( diff --git a/src/services/command/commands.ts b/src/services/command/commands.ts index 1cd5434745..0eb234692d 100644 --- a/src/services/command/commands.ts +++ b/src/services/command/commands.ts @@ -3,21 +3,23 @@ import * as path from "path" import matter from "gray-matter" import { getGlobalRooDirectory, getProjectRooDirectoryForCwd } from "../roo-config" import { getBuiltInCommands, getBuiltInCommand } from "./built-in-commands" +import { getMcpPromptsAsCommands, getMcpPromptCommand } from "./mcp-prompts" +import { McpHub } from "../mcp/McpHub" export interface Command { name: string content: string - source: "global" | "project" | "built-in" + source: "global" | "project" | "built-in" | "mcp" filePath: string description?: string argumentHint?: string } /** - * Get all available commands from built-in, global, and project directories - * Priority order: project > global > built-in (later sources override earlier ones) + * Get all available commands from built-in, global, project directories, and MCP servers + * Priority order: MCP prompts > project > global > built-in (later sources override earlier ones) */ -export async function getCommands(cwd: string): Promise { +export async function getCommands(cwd: string, mcpHub?: McpHub): Promise { const commands = new Map() // Add built-in commands first (lowest priority) @@ -30,23 +32,37 @@ export async function getCommands(cwd: string): Promise { const globalDir = path.join(getGlobalRooDirectory(), "commands") await scanCommandDirectory(globalDir, "global", commands) - // Scan project commands (highest priority - override both global and built-in) + // Scan project commands (override both global and built-in) const projectDir = path.join(getProjectRooDirectoryForCwd(cwd), "commands") await scanCommandDirectory(projectDir, "project", commands) + // Add MCP prompts as commands (highest priority - override all others) + const mcpCommands = await getMcpPromptsAsCommands(mcpHub) + for (const command of mcpCommands) { + commands.set(command.name, { ...command, source: "mcp" }) + } + return Array.from(commands.values()) } /** * Get a specific command by name (optimized to avoid scanning all commands) - * Priority order: project > global > built-in + * Priority order: MCP prompts > project > global > built-in */ -export async function getCommand(cwd: string, name: string): Promise { +export async function getCommand(cwd: string, name: string, mcpHub?: McpHub): Promise { + // Check if it's an MCP prompt command first (highest priority) + if (name.startsWith("mcp.") && mcpHub) { + const mcpCommand = await getMcpPromptCommand(mcpHub, name) + if (mcpCommand) { + return { ...mcpCommand, source: "mcp" } + } + } + // Try to find the command directly without scanning all commands const projectDir = path.join(getProjectRooDirectoryForCwd(cwd), "commands") const globalDir = path.join(getGlobalRooDirectory(), "commands") - // Check project directory first (highest priority) + // Check project directory first const projectCommand = await tryLoadCommand(projectDir, name, "project") if (projectCommand) { return projectCommand @@ -128,8 +144,8 @@ async function tryLoadCommand( /** * Get command names for autocomplete */ -export async function getCommandNames(cwd: string): Promise { - const commands = await getCommands(cwd) +export async function getCommandNames(cwd: string, mcpHub?: McpHub): Promise { + const commands = await getCommands(cwd, mcpHub) return commands.map((cmd) => cmd.name) } diff --git a/src/services/command/mcp-prompts.ts b/src/services/command/mcp-prompts.ts new file mode 100644 index 0000000000..1b03235528 --- /dev/null +++ b/src/services/command/mcp-prompts.ts @@ -0,0 +1,161 @@ +import { Command } from "./commands" +import { McpHub } from "../mcp/McpHub" +import { McpPrompt } from "../../shared/mcp" + +/** + * Convert MCP prompts to commands that can be used in the slash command system + */ +export async function getMcpPromptsAsCommands(mcpHub: McpHub | undefined): Promise { + if (!mcpHub) { + return [] + } + + const commands: Command[] = [] + const servers = mcpHub.getAllServers() + + for (const server of servers) { + if (server.disabled || server.status !== "connected" || !server.prompts) { + continue + } + + // Add each prompt as a command with the pattern: mcp.. + for (const prompt of server.prompts) { + const commandName = `mcp.${server.name}.${prompt.name}` + commands.push({ + name: commandName, + content: "", // Content will be fetched dynamically when the command is used + source: server.source === "project" ? "project" : "global", + filePath: "", // Virtual command, no file path + description: prompt.description || `MCP prompt from ${server.name}`, + argumentHint: getArgumentHint(prompt), + }) + } + } + + return commands +} + +/** + * Get a specific MCP prompt command by name + */ +export async function getMcpPromptCommand( + mcpHub: McpHub | undefined, + commandName: string, +): Promise { + if (!mcpHub || !commandName.startsWith("mcp.")) { + return undefined + } + + // Parse the command name: mcp.. + const parts = commandName.split(".") + if (parts.length < 3) { + return undefined + } + + const serverName = parts[1] + const promptName = parts.slice(2).join(".") // Handle prompt names with dots + + const servers = mcpHub.getAllServers() + const server = servers.find((s) => s.name === serverName) + + if (!server || server.disabled || server.status !== "connected" || !server.prompts) { + return undefined + } + + const prompt = server.prompts.find((p) => p.name === promptName) + if (!prompt) { + return undefined + } + + return { + name: commandName, + content: "", // Content will be fetched dynamically when the command is used + source: server.source === "project" ? "project" : "global", + filePath: "", // Virtual command, no file path + description: prompt.description || `MCP prompt from ${server.name}`, + argumentHint: getArgumentHint(prompt), + } +} + +/** + * Execute an MCP prompt and get the resulting content + */ +export async function executeMcpPrompt( + mcpHub: McpHub, + serverName: string, + promptName: string, + args?: Record, +): Promise { + try { + const response = await mcpHub.getPrompt(serverName, promptName, args) + + // Convert the prompt response to a string that can be used as command content + if (response.messages && response.messages.length > 0) { + // Combine all messages into a single string + const content = response.messages + .map((msg) => { + if (msg.content.type === "text" && msg.content.text) { + return msg.content.text + } else if (msg.content.type === "resource" && msg.content.resource?.text) { + return msg.content.resource.text + } + return "" + }) + .filter((text) => text.length > 0) + .join("\n\n") + + return content || "No content returned from MCP prompt" + } + + return "No messages returned from MCP prompt" + } catch (error) { + console.error(`Failed to execute MCP prompt ${promptName} on server ${serverName}:`, error) + throw new Error(`Failed to execute MCP prompt: ${error instanceof Error ? error.message : String(error)}`) + } +} + +/** + * Get argument hint for a prompt based on its arguments + */ +function getArgumentHint(prompt: McpPrompt): string | undefined { + if (!prompt.arguments || prompt.arguments.length === 0) { + return undefined + } + + const requiredArgs = prompt.arguments.filter((arg) => arg.required !== false) + const optionalArgs = prompt.arguments.filter((arg) => arg.required === false) + + const hints: string[] = [] + + if (requiredArgs.length > 0) { + hints.push(requiredArgs.map((arg) => `<${arg.name}>`).join(" ")) + } + + if (optionalArgs.length > 0) { + hints.push(optionalArgs.map((arg) => `[${arg.name}]`).join(" ")) + } + + return hints.join(" ") || undefined +} + +/** + * Parse arguments from a command string + */ +export function parsePromptArguments(prompt: McpPrompt, argsString: string): Record { + if (!prompt.arguments || prompt.arguments.length === 0) { + return {} + } + + const args: Record = {} + const parts = argsString.trim().split(/\s+/) + + // Simple positional argument parsing + // In a more sophisticated implementation, we could support named arguments + prompt.arguments.forEach((arg, index) => { + if (index < parts.length) { + args[arg.name] = parts[index] + } + }) + + return args +} diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index caca5ddb39..0c1afdccc6 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -9,6 +9,8 @@ import { ListResourceTemplatesResultSchema, ListToolsResultSchema, ReadResourceResultSchema, + ListPromptsResultSchema, + GetPromptResultSchema, } from "@modelcontextprotocol/sdk/types.js" import chokidar, { FSWatcher } from "chokidar" import delay from "delay" @@ -28,6 +30,8 @@ import { McpServer, McpTool, McpToolCallResponse, + McpPrompt, + McpGetPromptResponse, } from "../../shared/mcp" import { fileExistsAtPath } from "../../utils/fs" import { arePathsEqual, getWorkspacePath } from "../../utils/path" @@ -835,10 +839,11 @@ export class McpHub { connection.server.error = "" connection.server.instructions = client.getInstructions() - // Initial fetch of tools and resources + // Initial fetch of tools, resources, and prompts connection.server.tools = await this.fetchToolsList(name, source) connection.server.resources = await this.fetchResourcesList(name, source) connection.server.resourceTemplates = await this.fetchResourceTemplatesList(name, source) + connection.server.prompts = await this.fetchPromptsList(name, source) } catch (error) { // Update status with error const connection = this.findConnection(name, source) @@ -993,6 +998,49 @@ export class McpHub { } } + private async fetchPromptsList(serverName: string, source?: "global" | "project"): Promise { + try { + const connection = this.findConnection(serverName, source) + if (!connection || connection.type !== "connected") { + return [] + } + const response = await connection.client.request({ method: "prompts/list" }, ListPromptsResultSchema) + return response?.prompts || [] + } catch (error) { + // Prompts might not be supported by all servers, so we silently handle errors + // console.error(`Failed to fetch prompts for ${serverName}:`, error) + return [] + } + } + + async getPrompt( + serverName: string, + promptName: string, + args?: Record, + source?: "global" | "project", + ): Promise { + const connection = this.findConnection(serverName, source) + if (!connection || connection.type !== "connected") { + throw new Error( + `No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`, + ) + } + if (connection.server.disabled) { + throw new Error(`Server "${serverName}" is disabled and cannot be used`) + } + + return await connection.client.request( + { + method: "prompts/get", + params: { + name: promptName, + arguments: args, + }, + }, + GetPromptResultSchema, + ) + } + async deleteConnection(name: string, source?: "global" | "project"): Promise { // Clean up file watchers for this server this.removeFileWatchersForServer(name) @@ -1384,6 +1432,7 @@ export class McpHub { serverName, serverSource, ) + connection.server.prompts = await this.fetchPromptsList(serverName, serverSource) } } catch (error) { console.error(`Failed to refresh capabilities for ${serverName}:`, error) diff --git a/src/shared/mcp.ts b/src/shared/mcp.ts index ef1d51bad3..3ba950074e 100644 --- a/src/shared/mcp.ts +++ b/src/shared/mcp.ts @@ -13,6 +13,7 @@ export type McpServer = { tools?: McpTool[] resources?: McpResource[] resourceTemplates?: McpResourceTemplate[] + prompts?: McpPrompt[] disabled?: boolean timeout?: number source?: "global" | "project" @@ -42,6 +43,40 @@ export type McpResourceTemplate = { mimeType?: string } +export type McpPrompt = { + name: string + description?: string + arguments?: McpPromptArgument[] +} + +export type McpPromptArgument = { + name: string + description?: string + required?: boolean +} + +export type McpPromptMessage = { + role: "user" | "assistant" | "system" + content: { + type: "text" | "image" | "resource" + text?: string + data?: string + mimeType?: string + resource?: { + uri: string + text?: string + blob?: string + mimeType?: string + } + } +} + +export type McpGetPromptResponse = { + _meta?: Record + description?: string + messages: McpPromptMessage[] +} + export type McpResourceResponse = { _meta?: Record contents: Array<{ From fe96d0db102a599918cf891680bdc5e10c09e361 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Mon, 15 Sep 2025 22:07:20 +0000 Subject: [PATCH 2/2] test: add comprehensive tests for MCP prompts functionality - Test getMcpPromptsAsCommands for converting prompts to commands - Test executeMcpPrompt for executing prompts with arguments - Test parsePromptArguments for parsing positional arguments - Ensure proper handling of edge cases and errors --- src/__tests__/command-mentions.spec.ts | 10 +- .../command/__tests__/mcp-prompts.spec.ts | 392 ++++++++++++++++++ src/services/command/mcp-prompts.ts | 7 +- src/shared/ExtensionMessage.ts | 2 +- 4 files changed, 404 insertions(+), 7 deletions(-) create mode 100644 src/services/command/__tests__/mcp-prompts.spec.ts diff --git a/src/__tests__/command-mentions.spec.ts b/src/__tests__/command-mentions.spec.ts index 7ddaf3d092..4026d6ebb2 100644 --- a/src/__tests__/command-mentions.spec.ts +++ b/src/__tests__/command-mentions.spec.ts @@ -53,7 +53,7 @@ describe("Command Mentions", () => { const input = "/setup Please help me set up the project" const result = await callParseMentions(input) - expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup") + expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup", undefined) expect(result).toContain('') expect(result).toContain(commandContent) expect(result).toContain("") @@ -94,8 +94,8 @@ describe("Command Mentions", () => { const input = "/setup the project\nThen /deploy later" const result = await callParseMentions(input) - expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup") - expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "deploy") + expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup", undefined) + expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "deploy", undefined) expect(mockGetCommand).toHaveBeenCalledTimes(2) // Each unique command called once (optimized) expect(result).toContain('') expect(result).toContain("# Setup Environment") @@ -110,7 +110,7 @@ describe("Command Mentions", () => { const input = "/nonexistent command" const result = await callParseMentions(input) - expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "nonexistent") + expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "nonexistent", undefined) // The command should remain unchanged in the text expect(result).toBe("/nonexistent command") // Should not contain any command tags @@ -159,7 +159,7 @@ describe("Command Mentions", () => { const input = "/setup-dev for the project" const result = await callParseMentions(input) - expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup-dev") + expect(mockGetCommand).toHaveBeenCalledWith("/test/cwd", "setup-dev", undefined) expect(result).toContain('') expect(result).toContain("# Dev setup") }) diff --git a/src/services/command/__tests__/mcp-prompts.spec.ts b/src/services/command/__tests__/mcp-prompts.spec.ts new file mode 100644 index 0000000000..9e7d230c59 --- /dev/null +++ b/src/services/command/__tests__/mcp-prompts.spec.ts @@ -0,0 +1,392 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { getMcpPromptsAsCommands, executeMcpPrompt, parsePromptArguments } from "../mcp-prompts" +import { McpHub } from "../../mcp/McpHub" +import type { McpServer, McpPrompt } from "../../../shared/mcp" + +// Mock McpHub +vi.mock("../../mcp/McpHub") + +describe("MCP Prompts", () => { + let mockMcpHub: any + + beforeEach(() => { + vi.clearAllMocks() + mockMcpHub = { + getAllServers: vi.fn(), + getPrompt: vi.fn(), + } + }) + + describe("getMcpPromptsAsCommands", () => { + it("should return empty array when mcpHub is undefined", async () => { + const result = await getMcpPromptsAsCommands(undefined) + expect(result).toEqual([]) + }) + + it("should return empty array when no servers have prompts", async () => { + const servers: McpServer[] = [ + { + name: "test-server", + config: "test-config", + status: "connected", + tools: [], + resources: [], + prompts: [], + }, + ] + mockMcpHub.getAllServers.mockReturnValue(servers) + + const result = await getMcpPromptsAsCommands(mockMcpHub) + expect(result).toEqual([]) + }) + + it("should convert MCP prompts to commands", async () => { + const prompts: McpPrompt[] = [ + { + name: "generate-code", + description: "Generate code based on requirements", + arguments: [ + { + name: "language", + description: "Programming language", + required: true, + }, + ], + }, + { + name: "explain", + description: "Explain a concept", + arguments: [], + }, + ] + + const servers: McpServer[] = [ + { + name: "coding-assistant", + config: "test-config", + status: "connected", + tools: [], + resources: [], + prompts, + source: "global", + }, + ] + mockMcpHub.getAllServers.mockReturnValue(servers) + + const result = await getMcpPromptsAsCommands(mockMcpHub) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual({ + name: "mcp.coding-assistant.generate-code", + content: "", + source: "global", + filePath: "", + description: "Generate code based on requirements", + argumentHint: "", + }) + expect(result[1]).toEqual({ + name: "mcp.coding-assistant.explain", + content: "", + source: "global", + filePath: "", + description: "Explain a concept", + argumentHint: undefined, + }) + }) + + it("should handle multiple servers with prompts", async () => { + const server1: McpServer = { + name: "server1", + config: "test-config", + status: "connected", + tools: [], + resources: [], + prompts: [ + { + name: "prompt1", + description: "First prompt", + arguments: [], + }, + ], + source: "project", + } + + const server2: McpServer = { + name: "server2", + config: "test-config", + status: "connected", + tools: [], + resources: [], + prompts: [ + { + name: "prompt2", + description: "Second prompt", + arguments: [ + { + name: "param", + description: "A parameter", + required: false, + }, + ], + }, + ], + source: "global", + } + + mockMcpHub.getAllServers.mockReturnValue([server1, server2]) + + const result = await getMcpPromptsAsCommands(mockMcpHub) + + expect(result).toHaveLength(2) + expect(result[0].name).toBe("mcp.server1.prompt1") + expect(result[1].name).toBe("mcp.server2.prompt2") + expect(result[1].argumentHint).toBe("[param]") + }) + + it("should handle prompts with multiple arguments", async () => { + const prompts: McpPrompt[] = [ + { + name: "complex-prompt", + description: "A complex prompt", + arguments: [ + { + name: "arg1", + description: "First argument", + required: true, + }, + { + name: "arg2", + description: "Second argument", + required: false, + }, + { + name: "arg3", + description: "Third argument", + required: true, + }, + ], + }, + ] + + const servers: McpServer[] = [ + { + name: "test", + config: "test-config", + status: "connected", + tools: [], + resources: [], + prompts, + source: "global", + }, + ] + mockMcpHub.getAllServers.mockReturnValue(servers) + + const result = await getMcpPromptsAsCommands(mockMcpHub) + + expect(result[0].argumentHint).toBe(" [arg2]") + }) + }) + + describe("executeMcpPrompt", () => { + it("should execute prompt without arguments", async () => { + const mockResponse = { + messages: [ + { + content: { + type: "text" as const, + text: "Generated content", + }, + }, + ], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "test-server", "test-prompt", {}) + + expect(mockMcpHub.getPrompt).toHaveBeenCalledWith("test-server", "test-prompt", {}) + expect(result).toBe("Generated content") + }) + + it("should execute prompt with arguments", async () => { + const mockResponse = { + messages: [ + { + content: { + type: "text" as const, + text: "Result with args", + }, + }, + ], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "server", "prompt", { + language: "python", + description: "A test function", + }) + + expect(mockMcpHub.getPrompt).toHaveBeenCalledWith("server", "prompt", { + language: "python", + description: "A test function", + }) + expect(result).toBe("Result with args") + }) + + it("should handle multiple messages in response", async () => { + const mockResponse = { + messages: [ + { + content: { + type: "text" as const, + text: "First message", + }, + }, + { + content: { + type: "text" as const, + text: "Second message", + }, + }, + ], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "server", "prompt", {}) + + expect(result).toBe("First message\n\nSecond message") + }) + + it("should handle resource content types", async () => { + const mockResponse = { + messages: [ + { + content: { + type: "resource" as const, + resource: { + text: "Resource content", + }, + }, + }, + ], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "server", "prompt", {}) + + expect(result).toBe("Resource content") + }) + + it("should handle errors gracefully", async () => { + mockMcpHub.getPrompt.mockRejectedValue(new Error("Connection failed")) + + await expect(executeMcpPrompt(mockMcpHub, "server", "prompt", {})).rejects.toThrow( + "Failed to execute MCP prompt: Connection failed", + ) + }) + + it("should handle empty response", async () => { + const mockResponse = { + messages: [], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "server", "prompt", {}) + + expect(result).toBe("No messages returned from MCP prompt") + }) + + it("should handle messages without text content", async () => { + const mockResponse = { + messages: [ + { + content: { + type: "text" as const, + text: "", + }, + }, + { + content: { + type: "resource" as const, + resource: {}, + }, + }, + ], + } + mockMcpHub.getPrompt.mockResolvedValue(mockResponse) + + const result = await executeMcpPrompt(mockMcpHub, "server", "prompt", {}) + + expect(result).toBe("No content returned from MCP prompt") + }) + }) + + describe("parsePromptArguments", () => { + it("should return empty object for prompts without arguments", () => { + const prompt: McpPrompt = { + name: "test", + description: "Test prompt", + arguments: [], + } + + const result = parsePromptArguments(prompt, "some args") + expect(result).toEqual({}) + }) + + it("should parse positional arguments", () => { + const prompt: McpPrompt = { + name: "test", + description: "Test prompt", + arguments: [ + { name: "language", description: "Language", required: true }, + { name: "framework", description: "Framework", required: false }, + ], + } + + const result = parsePromptArguments(prompt, "python django") + expect(result).toEqual({ + language: "python", + framework: "django", + }) + }) + + it("should handle missing optional arguments", () => { + const prompt: McpPrompt = { + name: "test", + description: "Test prompt", + arguments: [ + { name: "required", description: "Required arg", required: true }, + { name: "optional", description: "Optional arg", required: false }, + ], + } + + const result = parsePromptArguments(prompt, "value1") + expect(result).toEqual({ + required: "value1", + }) + }) + + it("should handle extra arguments", () => { + const prompt: McpPrompt = { + name: "test", + description: "Test prompt", + arguments: [{ name: "arg1", description: "Arg 1", required: true }], + } + + const result = parsePromptArguments(prompt, "value1 value2 value3") + expect(result).toEqual({ + arg1: "value1", + }) + }) + + it("should handle empty argument string", () => { + const prompt: McpPrompt = { + name: "test", + description: "Test prompt", + arguments: [{ name: "arg1", description: "Arg 1", required: false }], + } + + const result = parsePromptArguments(prompt, "") + expect(result).toEqual({}) + }) + }) +}) diff --git a/src/services/command/mcp-prompts.ts b/src/services/command/mcp-prompts.ts index 1b03235528..45a93cef62 100644 --- a/src/services/command/mcp-prompts.ts +++ b/src/services/command/mcp-prompts.ts @@ -146,8 +146,13 @@ export function parsePromptArguments(prompt: McpPrompt, argsString: string): Rec return {} } + const trimmedArgs = argsString.trim() + if (trimmedArgs.length === 0) { + return {} + } + const args: Record = {} - const parts = argsString.trim().split(/\s+/) + const parts = trimmedArgs.split(/\s+/) // Simple positional argument parsing // In a more sophisticated implementation, we could support named arguments diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index aaddc520cb..951fa8581c 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -24,7 +24,7 @@ import { ModelRecord, RouterModels } from "./api" // Command interface for frontend/backend communication export interface Command { name: string - source: "global" | "project" | "built-in" + source: "global" | "project" | "built-in" | "mcp" filePath?: string description?: string argumentHint?: string