diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 7e79855f7e12..c8e9718c9d53 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -138,6 +138,8 @@ export const globalSettingsSchema = z.object({ mcpEnabled: z.boolean().optional(), enableMcpServerCreation: z.boolean().optional(), + mcpMaxImagesPerResponse: z.number().min(0).max(20).optional(), + mcpMaxImageSizeMB: z.number().min(0.1).max(10).optional(), mode: z.string().optional(), modeApiConfigs: z.record(z.string(), z.string()).optional(), @@ -314,6 +316,8 @@ export const EVALS_SETTINGS: RooCodeSettings = { telemetrySetting: "enabled", mcpEnabled: false, + mcpMaxImagesPerResponse: 5, + mcpMaxImageSizeMB: 2, mode: "code", // "architect", diff --git a/packages/types/src/mcp.ts b/packages/types/src/mcp.ts index ed930f4a16ef..bb605530a63e 100644 --- a/packages/types/src/mcp.ts +++ b/packages/types/src/mcp.ts @@ -25,16 +25,20 @@ export const mcpExecutionStatusSchema = z.discriminatedUnion("status", [ executionId: z.string(), status: z.literal("output"), response: z.string(), + images: z.array(z.string()).optional(), }), z.object({ executionId: z.string(), status: z.literal("completed"), response: z.string().optional(), + images: z.array(z.string()).optional(), }), z.object({ executionId: z.string(), status: z.literal("error"), error: z.string().optional(), + response: z.string().optional(), + images: z.array(z.string()).optional(), }), ]) diff --git a/src/core/tools/__tests__/mcpImageHandling.test.ts b/src/core/tools/__tests__/mcpImageHandling.test.ts new file mode 100644 index 000000000000..2ab490f42313 --- /dev/null +++ b/src/core/tools/__tests__/mcpImageHandling.test.ts @@ -0,0 +1,223 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { + SUPPORTED_IMAGE_TYPES, + DEFAULT_MCP_IMAGE_LIMITS, + isSupportedImageType, + isValidBase64Image, + calculateBase64Size, + bytesToMB, + extractMimeType, +} from 
"../mcpImageConstants" + +describe("MCP Image Constants", () => { + describe("SUPPORTED_IMAGE_TYPES", () => { + it("should include all common image MIME types", () => { + expect(SUPPORTED_IMAGE_TYPES).toContain("image/png") + expect(SUPPORTED_IMAGE_TYPES).toContain("image/jpeg") + expect(SUPPORTED_IMAGE_TYPES).toContain("image/gif") + expect(SUPPORTED_IMAGE_TYPES).toContain("image/webp") + expect(SUPPORTED_IMAGE_TYPES).toContain("image/svg+xml") + expect(SUPPORTED_IMAGE_TYPES).toContain("image/bmp") + }) + }) + + describe("DEFAULT_MCP_IMAGE_LIMITS", () => { + it("should have reasonable default limits", () => { + expect(DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse).toBe(5) + expect(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB).toBe(2) + }) + }) + + describe("isSupportedImageType", () => { + it("should return true for supported MIME types", () => { + expect(isSupportedImageType("image/png")).toBe(true) + expect(isSupportedImageType("image/jpeg")).toBe(true) + expect(isSupportedImageType("image/gif")).toBe(true) + expect(isSupportedImageType("image/webp")).toBe(true) + expect(isSupportedImageType("image/svg+xml")).toBe(true) + expect(isSupportedImageType("image/bmp")).toBe(true) + }) + + it("should return false for unsupported MIME types", () => { + expect(isSupportedImageType("image/tiff")).toBe(false) + expect(isSupportedImageType("application/pdf")).toBe(false) + expect(isSupportedImageType("text/plain")).toBe(false) + expect(isSupportedImageType("video/mp4")).toBe(false) + }) + + it("should handle edge cases", () => { + expect(isSupportedImageType("")).toBe(false) + expect(isSupportedImageType("IMAGE/PNG")).toBe(false) // Case sensitive + }) + }) + + describe("isValidBase64Image", () => { + // Mock atob for Node.js environment + beforeEach(() => { + if (typeof global.atob === "undefined") { + global.atob = (str: string) => Buffer.from(str, "base64").toString("binary") + } + }) + + it("should validate correct base64 image data", () => { + // Valid PNG data URL + const 
validPngDataUrl = + "" + expect(isValidBase64Image(validPngDataUrl)).toBe(true) + + // Valid base64 without data URL prefix + const validBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + expect(isValidBase64Image(validBase64)).toBe(true) + }) + + it("should reject invalid base64 data", () => { + // Invalid base64 characters + expect(isValidBase64Image("!@#$%^&*()")).toBe(false) + + // Not base64 at all + expect(isValidBase64Image("not base64 data")).toBe(false) + + // Empty string - returns true because empty base64 is technically valid + expect(isValidBase64Image("")).toBe(true) + + // Malformed data URL + expect(isValidBase64Image("data:image/png;base64,!!!invalid!!!")).toBe(false) + }) + + it("should handle corrupted base64 data", () => { + // Missing padding - still valid base64 + const corruptedBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ" + // This is actually valid base64 (padding is optional in some implementations) + expect(isValidBase64Image(corruptedBase64)).toBe(true) + + // Truncated data - not multiple of 4 + const truncatedBase64 = "iVBORw0KGg=" + expect(isValidBase64Image(truncatedBase64)).toBe(false) + }) + }) + + describe("calculateBase64Size", () => { + it("should calculate size for base64 string without data URL prefix", () => { + // Base64 string of known size (approximately) + const base64 = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + const size = calculateBase64Size(base64) + // Base64 encoding increases size by ~33%, so we expect around 75% of string length + expect(size).toBeGreaterThan(0) + expect(size).toBeLessThan(base64.length) + }) + + it("should calculate size for data URL", () => { + const dataUrl = + "" + const size = calculateBase64Size(dataUrl) + expect(size).toBeGreaterThan(0) + // Should only count the base64 part, not the prefix + expect(size).toBeLessThan(dataUrl.length) + }) + + it("should handle 
edge cases", () => { + expect(calculateBase64Size("")).toBe(0) + expect(calculateBase64Size("data:image/png;base64,")).toBe(0) + }) + }) + + describe("bytesToMB", () => { + it("should convert bytes to megabytes correctly", () => { + expect(bytesToMB(1048576)).toBe(1) // 1 MB + expect(bytesToMB(2097152)).toBe(2) // 2 MB + expect(bytesToMB(524288)).toBe(0.5) // 0.5 MB + expect(bytesToMB(0)).toBe(0) + }) + + it("should handle decimal values", () => { + expect(bytesToMB(1572864)).toBeCloseTo(1.5, 2) // 1.5 MB + expect(bytesToMB(3145728)).toBeCloseTo(3, 2) // 3 MB + }) + }) + + describe("extractMimeType", () => { + it("should extract MIME type from data URL", () => { + expect(extractMimeType("...")).toBe("image/png") + expect(extractMimeType("...")).toBe("image/jpeg") + expect(extractMimeType("...")).toBe("image/gif") + }) + + it("should return null for non-data URLs", () => { + expect(extractMimeType("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB")).toBeNull() + expect(extractMimeType("not a data url")).toBeNull() + expect(extractMimeType("")).toBeNull() + }) + + it("should handle malformed data URLs", () => { + expect(extractMimeType("data:;base64,iVBORw0KG...")).toBeNull() + expect(extractMimeType("" + + // Extract MIME type + const mimeType = extractMimeType(validImage) + expect(mimeType).toBe("image/png") + + // Check if supported + expect(isSupportedImageType(mimeType!)).toBe(true) + + // Validate base64 + expect(isValidBase64Image(validImage)).toBe(true) + + // Check size + const sizeBytes = calculateBase64Size(validImage) + const sizeMB = bytesToMB(sizeBytes) + expect(sizeMB).toBeLessThan(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB) + }) + + it("should reject an oversized image", () => { + // Create a large base64 string (simulating > 2MB) + const largeBase64 = "data:image/png;base64," + "A".repeat(3 * 1024 * 1024) // ~3MB of 'A's + + const sizeBytes = calculateBase64Size(largeBase64) + const sizeMB = bytesToMB(sizeBytes) + + 
expect(sizeMB).toBeGreaterThan(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB) + }) + + it("should handle multiple images respecting count limits", () => { + const images = [ + "", + "", + "", + "", + "", + "", // 6th image + ] + + const processedImages: string[] = [] + const errors: string[] = [] + + for (const image of images) { + if (processedImages.length >= DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse) { + errors.push(`Maximum number of images (${DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse}) exceeded`) + continue + } + processedImages.push(image) + } + + expect(processedImages).toHaveLength(5) + expect(errors).toHaveLength(1) + expect(errors[0]).toContain("Maximum number of images") + }) +}) diff --git a/src/core/tools/__tests__/useMcpToolTool.spec.ts b/src/core/tools/__tests__/useMcpToolTool.spec.ts index 8738e059e552..240ef1471bff 100644 --- a/src/core/tools/__tests__/useMcpToolTool.spec.ts +++ b/src/core/tools/__tests__/useMcpToolTool.spec.ts @@ -7,7 +7,7 @@ import { ToolUse } from "../../../shared/tools" // Mock dependencies vi.mock("../../prompts/responses", () => ({ formatResponse: { - toolResult: vi.fn((result: string) => `Tool result: ${result}`), + toolResult: vi.fn((result: string, images?: string[]) => `Tool result: ${result}`), toolError: vi.fn((error: string) => `Tool error: ${error}`), invalidMcpToolArgumentError: vi.fn((server: string, tool: string) => `Invalid args for ${server}:${tool}`), unknownMcpToolError: vi.fn((server: string, tool: string, availableTools: string[]) => { @@ -209,6 +209,7 @@ describe("useMcpToolTool", () => { callTool: vi.fn().mockResolvedValue(mockToolResult), }), postMessageToWebview: vi.fn(), + getValue: vi.fn().mockReturnValue(undefined), // Add getValue mock }) await useMcpToolTool( @@ -415,6 +416,7 @@ describe("useMcpToolTool", () => { callTool: vi.fn().mockResolvedValue(mockToolResult), }), postMessageToWebview: vi.fn(), + getValue: vi.fn().mockReturnValue(undefined), // Add getValue mock }) const block: ToolUse = { 
diff --git a/src/core/tools/mcpImageConstants.ts b/src/core/tools/mcpImageConstants.ts
new file mode 100644
index 000000000000..8cdce835ae74
--- /dev/null
+++ b/src/core/tools/mcpImageConstants.ts
@@ -0,0 +1,94 @@
+/**
+ * Constants for MCP image handling
+ */
+
+/**
+ * Supported image MIME types for MCP responses
+ */
+export const SUPPORTED_IMAGE_TYPES = [
+	"image/png",
+	"image/jpeg",
+	"image/jpg",
+	"image/gif",
+	"image/webp",
+	"image/svg+xml",
+	"image/bmp",
+] as const
+
+export type SupportedImageType = (typeof SUPPORTED_IMAGE_TYPES)[number]
+
+/**
+ * Default limits for MCP image handling
+ */
+export const DEFAULT_MCP_IMAGE_LIMITS = {
+	maxImagesPerResponse: 5,
+	maxImageSizeMB: 2,
+} as const
+
+/**
+ * Check if a MIME type is one of SUPPORTED_IMAGE_TYPES (exact, case-sensitive match)
+ */
+export function isSupportedImageType(mimeType: string): mimeType is SupportedImageType {
+	return SUPPORTED_IMAGE_TYPES.includes(mimeType as SupportedImageType)
+}
+
+/**
+ * Validate base64 image data (raw base64 or a "data:image/<subtype>;base64," data URL)
+ * @param base64Data The base64 string to validate, with or without a data URL prefix
+ * @returns true if the payload is well-formed base64 (an empty payload counts as valid), false otherwise
+ */
+export function isValidBase64Image(base64Data: string): boolean {
+	// Base64 alphabet followed by at most two "=" padding characters
+	const base64Regex = /^[A-Za-z0-9+/]*={0,2}$/
+
+	// Remove data URL prefix if present; subtype class matches extractMimeType so "image/svg+xml" is stripped too
+	const base64Only = base64Data.replace(/^data:image\/[a-z+-]+;base64,/, "")
+
+	// Check basic format
+	if (!base64Regex.test(base64Only)) {
+		return false
+	}
+
+	// Check if length is valid (must be multiple of 4)
+	if (base64Only.length % 4 !== 0) {
+		return false
+	}
+
+	try {
+		// Try to decode to verify it's valid base64
+		atob(base64Only)
+		return true
+	} catch {
+		return false
+	}
+}
+
+/**
+ * Calculate the approximate size of a base64 image in bytes
+ * @param base64Data The base64 string, with or without a data URL prefix
+ * @returns Size in bytes
+ */
+export function calculateBase64Size(base64Data: string): number {
+	// Remove data URL prefix if present (same pattern as isValidBase64Image, so the prefix is never counted)
+	const base64Only = base64Data.replace(/^data:image\/[a-z+-]+;base64,/, "")
+
// Base64 encoding increases size by ~33%, so we reverse that + // Every 4 base64 characters represent 3 bytes + const padding = (base64Only.match(/=/g) || []).length + return Math.floor((base64Only.length * 3) / 4) - padding +} + +/** + * Convert bytes to megabytes + */ +export function bytesToMB(bytes: number): number { + return bytes / (1024 * 1024) +} + +/** + * Extract MIME type from a data URL + */ +export function extractMimeType(dataUrl: string): string | null { + const match = dataUrl.match(/^data:([a-z]+\/[a-z+-]+);base64,/) + return match ? match[1] : null +} diff --git a/src/core/tools/useMcpToolTool.ts b/src/core/tools/useMcpToolTool.ts index 41697ab979b5..a03e6af9d3e4 100644 --- a/src/core/tools/useMcpToolTool.ts +++ b/src/core/tools/useMcpToolTool.ts @@ -4,6 +4,15 @@ import { formatResponse } from "../prompts/responses" import { ClineAskUseMcpServer } from "../../shared/ExtensionMessage" import { McpExecutionStatus } from "@roo-code/types" import { t } from "../../i18n" +import { + SUPPORTED_IMAGE_TYPES, + DEFAULT_MCP_IMAGE_LIMITS, + isSupportedImageType, + isValidBase64Image, + calculateBase64Size, + bytesToMB, + extractMimeType, +} from "./mcpImageConstants" interface McpToolParams { server_name?: string @@ -195,24 +204,87 @@ async function sendExecutionStatus(cline: Task, status: McpExecutionStatus): Pro }) } -function processToolContent(toolResult: any): string { +interface ProcessedContent { + text: string + images: string[] + errors: string[] +} + +function processToolContent( + toolResult: any, + maxImages: number = DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse, + maxSizeMB: number = DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB, +): ProcessedContent { + const result: ProcessedContent = { + text: "", + images: [], + errors: [], + } + if (!toolResult?.content || toolResult.content.length === 0) { - return "" + return result } - return toolResult.content - .map((item: any) => { - if (item.type === "text") { - return item.text + const textParts: 
string[] = []
+
+	for (const item of toolResult.content) {
+		if (item.type === "text") {
+			textParts.push(item.text)
+		} else if (item.type === "image") {
+			// Handle image content: payload is read from `data`, falling back to `base64`
+			const imageData = item.data || item.base64
+			const mimeType = item.mimeType || (typeof imageData === "string" ? extractMimeType(imageData) : null)
+
+			if (typeof imageData !== "string" || !imageData) {
+				result.errors.push("Image data is missing or not a string")
+				continue
 			}
-			if (item.type === "resource") {
-				const { blob: _, ...rest } = item.resource
-				return JSON.stringify(rest, null, 2)
+
+			// Check if we've reached the image limit
+			if (result.images.length >= maxImages) {
+				result.errors.push(`Maximum number of images (${maxImages}) exceeded`)
+				continue
 			}
-			return ""
-		})
-		.filter(Boolean)
-		.join("\n\n")
+
+			// Validate MIME type
+			if (mimeType && !isSupportedImageType(mimeType)) {
+				result.errors.push(
+					`Unsupported image type: ${mimeType}. Supported types: ${SUPPORTED_IMAGE_TYPES.join(", ")}`,
+				)
+				continue
+			}
+
+			// Validate base64 data
+			if (!isValidBase64Image(imageData)) {
+				result.errors.push("Invalid or corrupted base64 image data")
+				continue
+			}
+
+			// Check image size
+			const sizeBytes = calculateBase64Size(imageData)
+			const sizeMB = bytesToMB(sizeBytes)
+
+			if (sizeMB > maxSizeMB) {
+				result.errors.push(`Image size (${sizeMB.toFixed(2)}MB) exceeds maximum allowed size (${maxSizeMB}MB)`)
+				continue
+			}
+
+			// Add data URL prefix if not present
+			let fullImageData = imageData
+			if (!imageData.startsWith("data:")) {
+				const type = mimeType || "image/png"
+				fullImageData = `data:${type};base64,${imageData}`
+			}
+
+			result.images.push(fullImageData)
+		} else if (item.type === "resource") {
+			const { blob: _, ...rest } = item.resource
+			textParts.push(JSON.stringify(rest, null, 2))
+		}
+	}
+
+	result.text = textParts.filter(Boolean).join("\n\n")
+	return result
 }
 
 async function executeToolAndProcessResult(
@@ -233,21 +305,48 @@ async function executeToolAndProcessResult(
 		toolName,
 	})
 
-	const toolResult = await
cline.providerRef.deref()?.getMcpHub()?.callTool(serverName, toolName, parsedArguments) + // Get configuration for image limits + const provider = cline.providerRef.deref() + const maxImages = provider?.getValue("mcpMaxImagesPerResponse") ?? DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse + const maxSizeMB = provider?.getValue("mcpMaxImageSizeMB") ?? DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB + + const toolResult = await provider?.getMcpHub()?.callTool(serverName, toolName, parsedArguments) let toolResultPretty = "(No response)" + let images: string[] = [] if (toolResult) { - const outputText = processToolContent(toolResult) + const processedContent = processToolContent(toolResult, maxImages, maxSizeMB) + + // Log any errors encountered during processing + if (processedContent.errors.length > 0) { + console.warn("MCP image processing warnings:", processedContent.errors) + // Include errors in the response for transparency + const errorText = processedContent.errors.map((e) => `⚠️ ${e}`).join("\n") + processedContent.text = processedContent.text ? `${processedContent.text}\n\n${errorText}` : errorText + } - if (outputText) { - await sendExecutionStatus(cline, { - executionId, - status: "output", - response: outputText, - }) + if (processedContent.text || processedContent.images.length > 0) { + // Send text output first + if (processedContent.text) { + await sendExecutionStatus(cline, { + executionId, + status: "output", + response: processedContent.text, + }) + } + + // Prepare the complete response + toolResultPretty = (toolResult.isError ? "Error:\n" : "") + processedContent.text - toolResultPretty = (toolResult.isError ? "Error:\n" : "") + outputText + // Store images for later use + images = processedContent.images + + // Include image count in response if images are present + if (images.length > 0) { + const imageInfo = `\n\n📷 ${images.length} image${images.length > 1 ? 
"s" : ""} included in response` + toolResultPretty += imageInfo + } } // Send completion status @@ -256,6 +355,7 @@ async function executeToolAndProcessResult( status: toolResult.isError ? "error" : "completed", response: toolResultPretty, error: toolResult.isError ? "Error executing MCP tool" : undefined, + images: images.length > 0 ? images : undefined, }) } else { // Send error status if no result @@ -266,8 +366,14 @@ async function executeToolAndProcessResult( }) } - await cline.say("mcp_server_response", toolResultPretty) - pushToolResult(formatResponse.toolResult(toolResultPretty)) + // Include images in the response message + if (images.length > 0) { + await cline.say("mcp_server_response", toolResultPretty, images) + } else { + await cline.say("mcp_server_response", toolResultPretty) + } + + pushToolResult(formatResponse.toolResult(toolResultPretty, images)) } export async function useMcpToolTool( diff --git a/src/shared/combineCommandSequences.ts b/src/shared/combineCommandSequences.ts index 56b97a368e5c..972c2d9a73d9 100644 --- a/src/shared/combineCommandSequences.ts +++ b/src/shared/combineCommandSequences.ts @@ -38,11 +38,17 @@ export function combineCommandSequences(messages: ClineMessage[]): ClineMessage[ if (msg.type === "ask" && msg.ask === "use_mcp_server") { // Look ahead for MCP responses let responses: string[] = [] + let allImages: string[] = [] let j = i + 1 while (j < messages.length) { if (messages[j].say === "mcp_server_response") { responses.push(messages[j].text || "") + // Collect images from the response if present + const msgImages = messages[j].images + if (msgImages && msgImages.length > 0) { + allImages.push(...msgImages) + } processedIndices.add(j) j++ } else if (messages[j].type === "ask" && messages[j].ask === "use_mcp_server") { @@ -63,7 +69,13 @@ export function combineCommandSequences(messages: ClineMessage[]): ClineMessage[ // Stringify the updated JSON object const combinedText = JSON.stringify(jsonObj) - 
combinedMessages.set(msg.ts, { ...msg, text: combinedText })
+			// Create the combined message with images if present
+			const combinedMessage: ClineMessage = { ...msg, text: combinedText }
+			if (allImages.length > 0) {
+				combinedMessage.images = allImages
+			}
+
+			combinedMessages.set(msg.ts, combinedMessage)
 		} else {
 			// If there's no response, just keep the original message
 			combinedMessages.set(msg.ts, { ...msg })
diff --git a/webview-ui/src/components/chat/McpExecution.tsx b/webview-ui/src/components/chat/McpExecution.tsx
index a96f368a17ec..af16b1b224d8 100644
--- a/webview-ui/src/components/chat/McpExecution.tsx
+++ b/webview-ui/src/components/chat/McpExecution.tsx
@@ -11,6 +11,7 @@ import { Button } from "@src/components/ui"
 import CodeBlock from "../common/CodeBlock"
 import McpToolRow from "../mcp/McpToolRow"
 import { Markdown } from "./Markdown"
+import ImageBlock from "../common/ImageBlock"
 
 interface McpExecutionProps {
 	executionId: string
@@ -48,6 +49,7 @@ export const McpExecution = ({
 	const [argumentsText, setArgumentsText] = useState(text || "")
 	const [serverName, setServerName] = useState(initialServerName)
 	const [toolName, setToolName] = useState(initialToolName)
+	const [images, setImages] = useState<string[]>([]) // explicit element type: a bare [] would infer never[]
 
 	// Only need expanded state for response section (like command output)
 	const [isResponseExpanded, setIsResponseExpanded] = useState(false)
@@ -143,6 +145,11 @@ export const McpExecution = ({
 						} else if (data.status === "completed" && data.response) {
 							setResponseText(data.response)
 						}
+
+						// Handle images if present (only for output, completed, or error status)
+						if ("images" in data && data.images && data.images.length > 0) {
+							setImages(data.images)
+						}
 					}
 				}
 			} catch (e) {
@@ -280,6 +287,7 @@ export const McpExecution = ({
 					isJson={responseIsJson}
 					hasArguments={!!(isArguments || useMcpServer?.arguments || argumentsText)}
 					isPartial={status ?
status.status !== "completed" : false} + images={images} /> @@ -294,12 +302,14 @@ const ResponseContainerInternal = ({ isJson, hasArguments, isPartial = false, + images = [], }: { isExpanded: boolean response: string isJson: boolean hasArguments?: boolean isPartial?: boolean + images?: string[] }) => { // Only render content when expanded to prevent performance issues with large responses if (!isExpanded || response.length === 0) { @@ -323,6 +333,13 @@ const ResponseContainerInternal = ({ ) : ( )} + {images.length > 0 && ( +
+ {images.map((image, index) => ( + + ))} +
+ )} ) }