diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 349e32ced3..208ba563c6 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -275,9 +275,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH * @param prompt The text prompt for image generation * @param model The model to use for generation * @param apiKey The OpenRouter API key (must be explicitly provided) + * @param inputImage Optional base64 encoded input image data URL * @returns The generated image data and format, or an error */ - async generateImage(prompt: string, model: string, apiKey: string): Promise { + async generateImage( + prompt: string, + model: string, + apiKey: string, + inputImage?: string, + ): Promise { if (!apiKey) { return { success: false, @@ -299,7 +305,20 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH messages: [ { role: "user", - content: prompt, + content: inputImage + ? [ + { + type: "text", + text: prompt, + }, + { + type: "image_url", + image_url: { + url: inputImage, + }, + }, + ] + : prompt, }, ], modalities: ["image", "text"], diff --git a/src/core/prompts/tools/generate-image.ts b/src/core/prompts/tools/generate-image.ts index 7869765228..458b7ae8cf 100644 --- a/src/core/prompts/tools/generate-image.ts +++ b/src/core/prompts/tools/generate-image.ts @@ -2,19 +2,35 @@ import { ToolArgs } from "./types" export function getGenerateImageDescription(args: ToolArgs): string { return `## generate_image -Description: Request to generate an image using AI models through OpenRouter API. This tool creates images from text prompts and saves them to the specified path. +Description: Request to generate or edit an image using AI models through OpenRouter API. This tool can create new images from text prompts or modify existing images based on your instructions. When an input image is provided, the AI will apply the requested edits, transformations, or enhancements to that image. Parameters: -- prompt: (required) The text prompt describing the image to generate -- path: (required) The file path where the generated image should be saved (relative to the current workspace directory ${args.cwd}). The tool will automatically add the appropriate image extension if not provided. +- prompt: (required) The text prompt describing what to generate or how to edit the image +- path: (required) The file path where the generated/edited image should be saved (relative to the current workspace directory ${args.cwd}). The tool will automatically add the appropriate image extension if not provided. +- image: (optional) The file path to an input image to edit or transform (relative to the current workspace directory ${args.cwd}). Supported formats: PNG, JPG, JPEG, GIF, WEBP. Usage: Your image description here path/to/save/image.png +path/to/input/image.jpg Example: Requesting to generate a sunset image A beautiful sunset over mountains with vibrant orange and purple colors images/sunset.png + + +Example: Editing an existing image + +Transform this image into a watercolor painting style +images/watercolor-output.png +images/original-photo.jpg + + +Example: Upscaling and enhancing an image + +Upscale this image to higher resolution, enhance details, improve clarity and sharpness while maintaining the original content and composition +images/enhanced-photo.png +images/low-res-photo.jpg ` } diff --git a/src/core/tools/__tests__/generateImageTool.test.ts b/src/core/tools/__tests__/generateImageTool.test.ts new file mode 100644 index 0000000000..ac7e122841 --- /dev/null +++ b/src/core/tools/__tests__/generateImageTool.test.ts @@ -0,0 +1,313 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { generateImageTool } from "../generateImageTool" +import { ToolUse } from "../../../shared/tools" +import { Task } from "../../task/Task" +import * as fs from "fs/promises" +import * as pathUtils from "../../../utils/pathUtils" +import * as fileUtils from "../../../utils/fs" +import { formatResponse } from "../../prompts/responses" +import { EXPERIMENT_IDS } from "../../../shared/experiments" +import { OpenRouterHandler } from "../../../api/providers/openrouter" + +// Mock dependencies +vi.mock("fs/promises") +vi.mock("../../../utils/pathUtils") +vi.mock("../../../utils/fs") +vi.mock("../../../utils/safeWriteJson") +vi.mock("../../../api/providers/openrouter") + +describe("generateImageTool", () => { + let mockCline: any + let mockAskApproval: any + let mockHandleError: any + let mockPushToolResult: any + let mockRemoveClosingTag: any + + beforeEach(() => { + vi.clearAllMocks() + + // Setup mock Cline instance + mockCline = { + cwd: "/test/workspace", + consecutiveMistakeCount: 0, + recordToolError: vi.fn(), + recordToolUsage: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + say: vi.fn(), + rooIgnoreController: { + validateAccess: vi.fn().mockReturnValue(true), + }, + rooProtectedController: { + isWriteProtected: vi.fn().mockReturnValue(false), + }, + providerRef: { + deref: vi.fn().mockReturnValue({ + getState: vi.fn().mockResolvedValue({ + experiments: { + [EXPERIMENT_IDS.IMAGE_GENERATION]: true, + }, + apiConfiguration: { + openRouterImageGenerationSettings: { + openRouterApiKey: "test-api-key", + selectedModel: "google/gemini-2.5-flash-image-preview", + }, + }, + }), + }), + }, + fileContextTracker: { + trackFileContext: vi.fn(), + }, + didEditFile: false, + } + + mockAskApproval = vi.fn().mockResolvedValue(true) + mockHandleError = vi.fn() + mockPushToolResult = vi.fn() + mockRemoveClosingTag = vi.fn((tag, content) => content || "") + + // Mock file system operations + vi.mocked(fileUtils.fileExistsAtPath).mockResolvedValue(true) + vi.mocked(fs.readFile).mockResolvedValue(Buffer.from("fake-image-data")) + vi.mocked(fs.mkdir).mockResolvedValue(undefined) + vi.mocked(fs.writeFile).mockResolvedValue(undefined) + vi.mocked(pathUtils.isPathOutsideWorkspace).mockReturnValue(false) + }) + + describe("partial block handling", () => { + it("should return early when block is partial", async () => { + const partialBlock: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Generate a test image", + path: "test-image.png", + }, + partial: true, + } + + await generateImageTool( + mockCline as Task, + partialBlock, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + // Should not process anything when partial + expect(mockAskApproval).not.toHaveBeenCalled() + expect(mockPushToolResult).not.toHaveBeenCalled() + expect(mockCline.say).not.toHaveBeenCalled() + }) + + it("should return early when block is partial even with image parameter", async () => { + const partialBlock: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Upscale this image", + path: "upscaled-image.png", + image: "source-image.png", + }, + partial: true, + } + + await generateImageTool( + mockCline as Task, + partialBlock, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + // Should not process anything when partial + expect(mockAskApproval).not.toHaveBeenCalled() + expect(mockPushToolResult).not.toHaveBeenCalled() + expect(mockCline.say).not.toHaveBeenCalled() + expect(fs.readFile).not.toHaveBeenCalled() + }) + + it("should process when block is not partial", async () => { + const completeBlock: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Generate a test image", + path: "test-image.png", + }, + partial: false, + } + + // Mock the OpenRouterHandler generateImage method + const mockGenerateImage = vi.fn().mockResolvedValue({ + success: true, + imageData: "data:image/png;base64,fakebase64data", + }) + + vi.mocked(OpenRouterHandler).mockImplementation( + () => + ({ + generateImage: mockGenerateImage, + }) as any, + ) + + await generateImageTool( + mockCline as Task, + completeBlock, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + // Should process the complete block + expect(mockAskApproval).toHaveBeenCalled() + expect(mockGenerateImage).toHaveBeenCalled() + expect(mockPushToolResult).toHaveBeenCalled() + }) + }) + + describe("missing parameters", () => { + it("should handle missing prompt parameter", async () => { + const block: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + path: "test-image.png", + }, + partial: false, + } + + await generateImageTool( + mockCline as Task, + block, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + expect(mockCline.consecutiveMistakeCount).toBe(1) + expect(mockCline.recordToolError).toHaveBeenCalledWith("generate_image") + expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("generate_image", "prompt") + expect(mockPushToolResult).toHaveBeenCalledWith("Missing parameter error") + }) + + it("should handle missing path parameter", async () => { + const block: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Generate a test image", + }, + partial: false, + } + + await generateImageTool( + mockCline as Task, + block, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + expect(mockCline.consecutiveMistakeCount).toBe(1) + expect(mockCline.recordToolError).toHaveBeenCalledWith("generate_image") + expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("generate_image", "path") + expect(mockPushToolResult).toHaveBeenCalledWith("Missing parameter error") + }) + }) + + describe("experiment validation", () => { + it("should error when image generation experiment is disabled", async () => { + // Disable the experiment + mockCline.providerRef.deref().getState.mockResolvedValue({ + experiments: { + [EXPERIMENT_IDS.IMAGE_GENERATION]: false, + }, + }) + + const block: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Generate a test image", + path: "test-image.png", + }, + partial: false, + } + + await generateImageTool( + mockCline as Task, + block, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + expect(mockPushToolResult).toHaveBeenCalledWith( + formatResponse.toolError( + "Image generation is an experimental feature that must be enabled in settings. Please enable 'Image Generation' in the Experimental Settings section.", + ), + ) + }) + }) + + describe("input image validation", () => { + it("should handle non-existent input image", async () => { + vi.mocked(fileUtils.fileExistsAtPath).mockResolvedValue(false) + + const block: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Upscale this image", + path: "upscaled.png", + image: "non-existent.png", + }, + partial: false, + } + + await generateImageTool( + mockCline as Task, + block, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + expect(mockCline.say).toHaveBeenCalledWith("error", expect.stringContaining("Input image not found")) + expect(mockPushToolResult).toHaveBeenCalledWith(expect.stringContaining("Input image not found")) + }) + + it("should handle unsupported image format", async () => { + const block: ToolUse = { + type: "tool_use", + name: "generate_image", + params: { + prompt: "Upscale this image", + path: "upscaled.png", + image: "test.bmp", // Unsupported format + }, + partial: false, + } + + await generateImageTool( + mockCline as Task, + block, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + expect(mockCline.say).toHaveBeenCalledWith("error", expect.stringContaining("Unsupported image format")) + expect(mockPushToolResult).toHaveBeenCalledWith(expect.stringContaining("Unsupported image format")) + }) + }) +}) diff --git a/src/core/tools/generateImageTool.ts b/src/core/tools/generateImageTool.ts index f3ffeb55ca..4bb67d629c 100644 --- a/src/core/tools/generateImageTool.ts +++ b/src/core/tools/generateImageTool.ts @@ -24,6 +24,7 @@ export async function generateImageTool( ) { const prompt: string | undefined = block.params.prompt const relPath: string | undefined = block.params.path + const inputImagePath: string | undefined = block.params.image // Check if the experiment is enabled const provider = cline.providerRef.deref() @@ -39,8 +40,7 @@ export async function generateImageTool( return } - if (block.partial && (!prompt || !relPath)) { - // Wait for complete parameters + if (block.partial) { return } @@ -66,6 +66,66 @@ export async function generateImageTool( return } + // If input image is provided, validate it exists and can be read + let inputImageData: string | undefined + if (inputImagePath) { + const inputImageFullPath = path.resolve(cline.cwd, inputImagePath) + + // Check if input image exists + const inputImageExists = await fileExistsAtPath(inputImageFullPath) + if (!inputImageExists) { + await cline.say("error", `Input image not found: ${getReadablePath(cline.cwd, inputImagePath)}`) + pushToolResult( + formatResponse.toolError(`Input image not found: ${getReadablePath(cline.cwd, inputImagePath)}`), + ) + return + } + + // Validate input image access permissions + const inputImageAccessAllowed = cline.rooIgnoreController?.validateAccess(inputImagePath) + if (!inputImageAccessAllowed) { + await cline.say("rooignore_error", inputImagePath) + pushToolResult(formatResponse.toolError(formatResponse.rooIgnoreError(inputImagePath))) + return + } + + // Read the input image file + try { + const imageBuffer = await fs.readFile(inputImageFullPath) + const imageExtension = path.extname(inputImageFullPath).toLowerCase().replace(".", "") + + // Validate image format + const supportedFormats = ["png", "jpg", "jpeg", "gif", "webp"] + if (!supportedFormats.includes(imageExtension)) { + await cline.say( + "error", + `Unsupported image format: ${imageExtension}. Supported formats: ${supportedFormats.join(", ")}`, + ) + pushToolResult( + formatResponse.toolError( + `Unsupported image format: ${imageExtension}. Supported formats: ${supportedFormats.join(", ")}`, + ), + ) + return + } + + // Convert to base64 data URL + const mimeType = imageExtension === "jpg" ? "jpeg" : imageExtension + inputImageData = `data:image/${mimeType};base64,${imageBuffer.toString("base64")}` + } catch (error) { + await cline.say( + "error", + `Failed to read input image: ${error instanceof Error ? error.message : "Unknown error"}`, + ) + pushToolResult( + formatResponse.toolError( + `Failed to read input image: ${error instanceof Error ? error.message : "Unknown error"}`, + ), + ) + return + } + } + // Check if file is write-protected const isWriteProtected = cline.rooProtectedController?.isWriteProtected(relPath) || false @@ -110,6 +170,7 @@ export async function generateImageTool( const approvalMessage = JSON.stringify({ ...sharedMessageProps, content: prompt, + ...(inputImagePath && { inputImage: getReadablePath(cline.cwd, inputImagePath) }), }) const didApprove = await askApproval("tool", approvalMessage, undefined, isWriteProtected) @@ -121,8 +182,13 @@ export async function generateImageTool( // Create a temporary OpenRouter handler with minimal options const openRouterHandler = new OpenRouterHandler({} as any) - // Call the generateImage method with the explicit API key - const result = await openRouterHandler.generateImage(prompt, selectedModel, openRouterApiKey) + // Call the generateImage method with the explicit API key and optional input image + const result = await openRouterHandler.generateImage( + prompt, + selectedModel, + openRouterApiKey, + inputImageData, + ) if (!result.success) { await cline.say("error", result.error || "Failed to generate image") diff --git a/src/shared/tools.ts b/src/shared/tools.ts index f15e8ef4c9..8a8776764e 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -66,6 +66,7 @@ export const toolParamNames = [ "args", "todos", "prompt", + "image", ] as const export type ToolParamName = (typeof toolParamNames)[number] @@ -167,7 +168,7 @@ export interface SearchAndReplaceToolUse extends ToolUse { export interface GenerateImageToolUse extends ToolUse { name: "generate_image" - params: Partial, "prompt" | "path">> + params: Partial, "prompt" | "path" | "image">> } // Define tool group configuration