Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/types/src/global-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ export const globalSettingsSchema = z.object({

mcpEnabled: z.boolean().optional(),
enableMcpServerCreation: z.boolean().optional(),
mcpMaxImagesPerResponse: z.number().min(0).max(20).optional(),
mcpMaxImageSizeMB: z.number().min(0.1).max(10).optional(),

mode: z.string().optional(),
modeApiConfigs: z.record(z.string(), z.string()).optional(),
Expand Down Expand Up @@ -314,6 +316,8 @@ export const EVALS_SETTINGS: RooCodeSettings = {
telemetrySetting: "enabled",

mcpEnabled: false,
mcpMaxImagesPerResponse: 5,
mcpMaxImageSizeMB: 2,

mode: "code", // "architect",

Expand Down
4 changes: 4 additions & 0 deletions packages/types/src/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@ export const mcpExecutionStatusSchema = z.discriminatedUnion("status", [
executionId: z.string(),
status: z.literal("output"),
response: z.string(),
images: z.array(z.string()).optional(),
}),
z.object({
executionId: z.string(),
status: z.literal("completed"),
response: z.string().optional(),
images: z.array(z.string()).optional(),
}),
z.object({
executionId: z.string(),
status: z.literal("error"),
error: z.string().optional(),
response: z.string().optional(),
images: z.array(z.string()).optional(),
}),
])

Expand Down
223 changes: 223 additions & 0 deletions src/core/tools/__tests__/mcpImageHandling.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import {
SUPPORTED_IMAGE_TYPES,
DEFAULT_MCP_IMAGE_LIMITS,
isSupportedImageType,
isValidBase64Image,
calculateBase64Size,
bytesToMB,
extractMimeType,
} from "../mcpImageConstants"

describe("MCP Image Constants", () => {
describe("SUPPORTED_IMAGE_TYPES", () => {
it("should include all common image MIME types", () => {
expect(SUPPORTED_IMAGE_TYPES).toContain("image/png")
expect(SUPPORTED_IMAGE_TYPES).toContain("image/jpeg")
expect(SUPPORTED_IMAGE_TYPES).toContain("image/gif")
expect(SUPPORTED_IMAGE_TYPES).toContain("image/webp")
expect(SUPPORTED_IMAGE_TYPES).toContain("image/svg+xml")
expect(SUPPORTED_IMAGE_TYPES).toContain("image/bmp")
})
})

describe("DEFAULT_MCP_IMAGE_LIMITS", () => {
it("should have reasonable default limits", () => {
expect(DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse).toBe(5)
expect(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB).toBe(2)
})
})

describe("isSupportedImageType", () => {
it("should return true for supported MIME types", () => {
expect(isSupportedImageType("image/png")).toBe(true)
expect(isSupportedImageType("image/jpeg")).toBe(true)
expect(isSupportedImageType("image/gif")).toBe(true)
expect(isSupportedImageType("image/webp")).toBe(true)
expect(isSupportedImageType("image/svg+xml")).toBe(true)
expect(isSupportedImageType("image/bmp")).toBe(true)
})

it("should return false for unsupported MIME types", () => {
expect(isSupportedImageType("image/tiff")).toBe(false)
expect(isSupportedImageType("application/pdf")).toBe(false)
expect(isSupportedImageType("text/plain")).toBe(false)
expect(isSupportedImageType("video/mp4")).toBe(false)
})

it("should handle edge cases", () => {
expect(isSupportedImageType("")).toBe(false)
expect(isSupportedImageType("IMAGE/PNG")).toBe(false) // Case sensitive
})
})

describe("isValidBase64Image", () => {
// Mock atob for Node.js environment
beforeEach(() => {
if (typeof global.atob === "undefined") {
global.atob = (str: string) => Buffer.from(str, "base64").toString("binary")
}
})

it("should validate correct base64 image data", () => {
// Valid PNG data URL
const validPngDataUrl =
""
expect(isValidBase64Image(validPngDataUrl)).toBe(true)

// Valid base64 without data URL prefix
const validBase64 =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
expect(isValidBase64Image(validBase64)).toBe(true)
})

it("should reject invalid base64 data", () => {
// Invalid base64 characters
expect(isValidBase64Image("!@#$%^&*()")).toBe(false)

// Not base64 at all
expect(isValidBase64Image("not base64 data")).toBe(false)

// Empty string - returns true because empty base64 is technically valid
expect(isValidBase64Image("")).toBe(true)

// Malformed data URL
expect(isValidBase64Image("data:image/png;base64,!!!invalid!!!")).toBe(false)
})

it("should handle corrupted base64 data", () => {
// Missing padding - still valid base64
const corruptedBase64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ"
// This is actually valid base64 (padding is optional in some implementations)
expect(isValidBase64Image(corruptedBase64)).toBe(true)

// Truncated data - not multiple of 4
const truncatedBase64 = "iVBORw0KGg="
expect(isValidBase64Image(truncatedBase64)).toBe(false)
})
})

describe("calculateBase64Size", () => {
it("should calculate size for base64 string without data URL prefix", () => {
// Base64 string of known size (approximately)
const base64 =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
const size = calculateBase64Size(base64)
// Base64 encoding increases size by ~33%, so we expect around 75% of string length
expect(size).toBeGreaterThan(0)
expect(size).toBeLessThan(base64.length)
})

it("should calculate size for data URL", () => {
const dataUrl =
""
const size = calculateBase64Size(dataUrl)
expect(size).toBeGreaterThan(0)
// Should only count the base64 part, not the prefix
expect(size).toBeLessThan(dataUrl.length)
})

it("should handle edge cases", () => {
expect(calculateBase64Size("")).toBe(0)
expect(calculateBase64Size("data:image/png;base64,")).toBe(0)
})
})

describe("bytesToMB", () => {
it("should convert bytes to megabytes correctly", () => {
expect(bytesToMB(1048576)).toBe(1) // 1 MB
expect(bytesToMB(2097152)).toBe(2) // 2 MB
expect(bytesToMB(524288)).toBe(0.5) // 0.5 MB
expect(bytesToMB(0)).toBe(0)
})

it("should handle decimal values", () => {
expect(bytesToMB(1572864)).toBeCloseTo(1.5, 2) // 1.5 MB
expect(bytesToMB(3145728)).toBeCloseTo(3, 2) // 3 MB
})
})

describe("extractMimeType", () => {
it("should extract MIME type from data URL", () => {
expect(extractMimeType("...")).toBe("image/png")
expect(extractMimeType("...")).toBe("image/jpeg")
expect(extractMimeType("...")).toBe("image/gif")
})

it("should return null for non-data URLs", () => {
expect(extractMimeType("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB")).toBeNull()
expect(extractMimeType("not a data url")).toBeNull()
expect(extractMimeType("")).toBeNull()
})

it("should handle malformed data URLs", () => {
expect(extractMimeType("data:;base64,iVBORw0KG...")).toBeNull()
expect(extractMimeType(""

// Extract MIME type
const mimeType = extractMimeType(validImage)
expect(mimeType).toBe("image/png")

// Check if supported
expect(isSupportedImageType(mimeType!)).toBe(true)

// Validate base64
expect(isValidBase64Image(validImage)).toBe(true)

// Check size
const sizeBytes = calculateBase64Size(validImage)
const sizeMB = bytesToMB(sizeBytes)
expect(sizeMB).toBeLessThan(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB)
})

it("should reject an oversized image", () => {
// Create a large base64 string (simulating > 2MB)
const largeBase64 = "data:image/png;base64," + "A".repeat(3 * 1024 * 1024) // ~3MB of 'A's

const sizeBytes = calculateBase64Size(largeBase64)
const sizeMB = bytesToMB(sizeBytes)

expect(sizeMB).toBeGreaterThan(DEFAULT_MCP_IMAGE_LIMITS.maxImageSizeMB)
})

it("should handle multiple images respecting count limits", () => {
const images = [
"",
"",
"",
"",
"",
"", // 6th image
]

const processedImages: string[] = []
const errors: string[] = []

for (const image of images) {
if (processedImages.length >= DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse) {
errors.push(`Maximum number of images (${DEFAULT_MCP_IMAGE_LIMITS.maxImagesPerResponse}) exceeded`)
continue
}
processedImages.push(image)
}

expect(processedImages).toHaveLength(5)
expect(errors).toHaveLength(1)
expect(errors[0]).toContain("Maximum number of images")
})
})
4 changes: 3 additions & 1 deletion src/core/tools/__tests__/useMcpToolTool.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { ToolUse } from "../../../shared/tools"
// Mock dependencies
vi.mock("../../prompts/responses", () => ({
formatResponse: {
toolResult: vi.fn((result: string) => `Tool result: ${result}`),
toolResult: vi.fn((result: string, images?: string[]) => `Tool result: ${result}`),
toolError: vi.fn((error: string) => `Tool error: ${error}`),
invalidMcpToolArgumentError: vi.fn((server: string, tool: string) => `Invalid args for ${server}:${tool}`),
unknownMcpToolError: vi.fn((server: string, tool: string, availableTools: string[]) => {
Expand Down Expand Up @@ -209,6 +209,7 @@ describe("useMcpToolTool", () => {
callTool: vi.fn().mockResolvedValue(mockToolResult),
}),
postMessageToWebview: vi.fn(),
getValue: vi.fn().mockReturnValue(undefined), // Add getValue mock
})

await useMcpToolTool(
Expand Down Expand Up @@ -415,6 +416,7 @@ describe("useMcpToolTool", () => {
callTool: vi.fn().mockResolvedValue(mockToolResult),
}),
postMessageToWebview: vi.fn(),
getValue: vi.fn().mockReturnValue(undefined), // Add getValue mock
})

const block: ToolUse = {
Expand Down
94 changes: 94 additions & 0 deletions src/core/tools/mcpImageConstants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Constants for MCP image handling
*/

/**
* Supported image MIME types for MCP responses
*/
export const SUPPORTED_IMAGE_TYPES = [
"image/png",
"image/jpeg",
"image/jpg",
"image/gif",
"image/webp",
"image/svg+xml",
"image/bmp",
] as const

export type SupportedImageType = (typeof SUPPORTED_IMAGE_TYPES)[number]

/**
* Default limits for MCP image handling
*/
export const DEFAULT_MCP_IMAGE_LIMITS = {
maxImagesPerResponse: 5,
maxImageSizeMB: 2,
} as const

/**
* Check if a MIME type is supported for images
*/
export function isSupportedImageType(mimeType: string): mimeType is SupportedImageType {
return SUPPORTED_IMAGE_TYPES.includes(mimeType as SupportedImageType)
}

/**
* Validate base64 image data
* @param base64Data The base64 string to validate
* @returns true if valid, false otherwise
*/
export function isValidBase64Image(base64Data: string): boolean {
// Check if it's a valid base64 string
const base64Regex = /^[A-Za-z0-9+/]*={0,2}$/

// Remove data URL prefix if present
const base64Only = base64Data.replace(/^data:image\/[a-z]+;base64,/, "")

// Check basic format
if (!base64Regex.test(base64Only)) {
return false
}

// Check if length is valid (must be multiple of 4)
if (base64Only.length % 4 !== 0) {
return false
}

try {
// Try to decode to verify it's valid base64
atob(base64Only)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the atob function guaranteed to be available in all environments where this code runs? I see it's mocked in tests, but we might want to add a fallback or check for its existence to prevent runtime errors in edge cases:

Suggested change
atob(base64Only)
try {
// Try to decode to verify it's valid base64
if (typeof atob !== 'undefined') {
atob(base64Only)
} else {
// Fallback for environments without atob
Buffer.from(base64Only, 'base64').toString('binary')
}
return true
} catch {
return false
}

return true
} catch {
return false
}
}

/**
* Calculate the approximate size of a base64 image in bytes
* @param base64Data The base64 string
* @returns Size in bytes
*/
export function calculateBase64Size(base64Data: string): number {
// Remove data URL prefix if present
const base64Only = base64Data.replace(/^data:image\/[a-z]+;base64,/, "")

// Base64 encoding increases size by ~33%, so we reverse that
// Every 4 base64 characters represent 3 bytes
const padding = (base64Only.match(/=/g) || []).length
return Math.floor((base64Only.length * 3) / 4) - padding
}

/**
* Convert bytes to megabytes
*/
export function bytesToMB(bytes: number): number {
return bytes / (1024 * 1024)
}

/**
* Extract MIME type from a data URL
*/
export function extractMimeType(dataUrl: string): string | null {
const match = dataUrl.match(/^data:([a-z]+\/[a-z+-]+);base64,/)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider extracting this regex pattern to a named constant for better maintainability:

Suggested change
const match = dataUrl.match(/^data:([a-z]+\/[a-z+-]+);base64,/)
const DATA_URL_MIME_PATTERN = /^data:([a-z]+\/[a-z+-]+);base64,/
export function extractMimeType(dataUrl: string): string | null {
const match = dataUrl.match(DATA_URL_MIME_PATTERN)
return match ? match[1] : null
}

return match ? match[1] : null
}
Loading
Loading