diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 207c60a524..c476cf3326 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -77,6 +77,14 @@ const baseProviderSettingsSchema = z.object({ reasoningEffort: reasoningEffortsSchema.optional(), modelMaxTokens: z.number().optional(), modelMaxThinkingTokens: z.number().optional(), + + // External MCP server settings for enhance prompt + enhancePrompt: z + .object({ + useExternalServer: z.boolean().optional(), + endpoint: z.string().url().optional(), + }) + .optional(), }) // Several of the providers share common model config properties. diff --git a/src/core/webview/__tests__/messageEnhancer.test.ts b/src/core/webview/__tests__/messageEnhancer.test.ts index f6f6b44e1d..5b8dfca043 100644 --- a/src/core/webview/__tests__/messageEnhancer.test.ts +++ b/src/core/webview/__tests__/messageEnhancer.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" import { MessageEnhancer } from "../messageEnhancer" -import { ProviderSettings, ClineMessage } from "@roo-code/types" +import { ProviderSettings, ClineMessage, getModelId } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" import * as singleCompletionHandlerModule from "../../../utils/single-completion-handler" import { ProviderSettingsManager } from "../../config/ProviderSettingsManager" @@ -9,6 +9,9 @@ import { ProviderSettingsManager } from "../../config/ProviderSettingsManager" vi.mock("../../../utils/single-completion-handler") vi.mock("@roo-code/telemetry") +// Mock global fetch +global.fetch = vi.fn() + describe("MessageEnhancer", () => { let mockProviderSettingsManager: ProviderSettingsManager let mockSingleCompletionHandler: ReturnType @@ -254,6 +257,255 @@ describe("MessageEnhancer", () => { // Should not include task history section expect(calledPrompt).not.toContain("previous conversation context") }) + + describe("External MCP Server", () => { + beforeEach(() => { + vi.mocked(global.fetch).mockReset() + }) + + it("should use external MCP server when enabled", async () => { + const mockResponse = { + enhancedPrompt: "Enhanced via external server", + } + vi.mocked(global.fetch).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockResponse), + } as any) + + const configWithExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + endpoint: "http://localhost:8000/enhance", + }, + } + + const result = await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(result.success).toBe(true) + expect(result.enhancedText).toBe("Enhanced via external server") + expect(global.fetch).toHaveBeenCalledWith("http://localhost:8000/enhance", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + prompt: "Test prompt", + context: [], + model: "gpt-4", + }), + }) + // Should not call internal enhancement + expect(mockSingleCompletionHandler).not.toHaveBeenCalled() + }) + + it("should include context messages when using external server", async () => { + const mockResponse = { + enhancedPrompt: "Enhanced with context", + } + vi.mocked(global.fetch).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockResponse), + } as any) + + const configWithExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + endpoint: "http://localhost:8000/enhance", + }, + } + + const mockClineMessages: ClineMessage[] = [ + { type: "ask", text: "User message", ts: 1000 }, + { type: "say", say: "text", text: "Assistant response", ts: 2000 }, + ] + + await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + currentClineMessages: mockClineMessages, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(global.fetch).toHaveBeenCalledWith("http://localhost:8000/enhance", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + prompt: "Test prompt", + context: [ + { role: "user", content: "User message" }, + { role: "assistant", content: "Assistant response" }, + ], + model: "gpt-4", + }), + }) + }) + + it("should fall back to internal enhancement when external server fails", async () => { + vi.mocked(global.fetch).mockRejectedValue(new Error("Network error")) + + const configWithExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + endpoint: "http://localhost:8000/enhance", + }, + } + + const result = await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(result.success).toBe(true) + expect(result.enhancedText).toBe("Enhanced prompt text") + expect(mockSingleCompletionHandler).toHaveBeenCalled() + }) + + it("should fall back when external server returns non-ok status", async () => { + vi.mocked(global.fetch).mockResolvedValue({ + ok: false, + status: 500, + statusText: "Internal Server Error", + } as any) + + const configWithExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + endpoint: "http://localhost:8000/enhance", + }, + } + + const result = await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(result.success).toBe(true) + expect(result.enhancedText).toBe("Enhanced prompt text") + expect(mockSingleCompletionHandler).toHaveBeenCalled() + }) + + it("should fall back when external server response is missing enhancedPrompt", async () => { + vi.mocked(global.fetch).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ wrongField: "value" }), + } as any) + + const configWithExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + endpoint: "http://localhost:8000/enhance", + }, + } + + const result = await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(result.success).toBe(true) + expect(result.enhancedText).toBe("Enhanced prompt text") + expect(mockSingleCompletionHandler).toHaveBeenCalled() + }) + + it("should not use external server when useExternalServer is false", async () => { + const configWithDisabledExternalServer: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: false, + endpoint: "http://localhost:8000/enhance", + }, + } + + await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithDisabledExternalServer, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(global.fetch).not.toHaveBeenCalled() + expect(mockSingleCompletionHandler).toHaveBeenCalled() + }) + + it("should not use external server when endpoint is missing", async () => { + const configWithoutEndpoint: ProviderSettings = { + ...mockApiConfiguration, + enhancePrompt: { + useExternalServer: true, + }, + } + + await MessageEnhancer.enhanceMessage({ + text: "Test prompt", + apiConfiguration: configWithoutEndpoint, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + expect(global.fetch).not.toHaveBeenCalled() + expect(mockSingleCompletionHandler).toHaveBeenCalled() + }) + + it("should handle different model ID fields correctly", async () => { + vi.mocked(global.fetch).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ enhancedPrompt: "Enhanced" }), + } as any) + + // Test with different provider configurations + const configs = [ + { + ...mockApiConfiguration, + apiModelId: "model-1", + enhancePrompt: { useExternalServer: true, endpoint: "http://localhost:8000/enhance" }, + }, + { + apiProvider: "ollama" as const, + ollamaModelId: "llama2", + enhancePrompt: { useExternalServer: true, endpoint: "http://localhost:8000/enhance" }, + }, + { + apiProvider: "openrouter" as const, + openRouterModelId: "gpt-4", + enhancePrompt: { useExternalServer: true, endpoint: "http://localhost:8000/enhance" }, + }, + ] + + for (const config of configs) { + vi.mocked(global.fetch).mockClear() + + await MessageEnhancer.enhanceMessage({ + text: "Test", + apiConfiguration: config, + listApiConfigMeta: mockListApiConfigMeta, + providerSettingsManager: mockProviderSettingsManager, + }) + + const expectedModel = getModelId(config) || "unknown" + expect(global.fetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + body: expect.stringContaining(`"model":"${expectedModel}"`), + }), + ) + } + }) + }) }) describe("captureTelemetry", () => { diff --git a/src/core/webview/messageEnhancer.ts b/src/core/webview/messageEnhancer.ts index 89df7b5b59..b846bdd318 100644 --- a/src/core/webview/messageEnhancer.ts +++ b/src/core/webview/messageEnhancer.ts @@ -1,4 +1,4 @@ -import { ProviderSettings, ClineMessage, GlobalState, TelemetryEventName } from "@roo-code/types" +import { ProviderSettings, ClineMessage, GlobalState, TelemetryEventName, getModelId } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" import { supportPrompt } from "../../shared/support-prompt" import { singleCompletionHandler } from "../../utils/single-completion-handler" @@ -58,6 +58,55 @@ export class MessageEnhancer { } } + // Check if external MCP server is enabled + if (configToUse.enhancePrompt?.useExternalServer && configToUse.enhancePrompt?.endpoint) { + try { + // Prepare context messages for external server + const contextMessages = + currentClineMessages + ?.filter((msg) => { + if (msg.type === "ask" && msg.text) return true + if (msg.type === "say" && msg.say === "text" && msg.text) return true + return false + }) + .slice(-10) + .map((msg) => ({ + role: msg.type === "ask" ? "user" : "assistant", + content: msg.text || "", + })) || [] + + // Make request to external MCP server + const response = await fetch(configToUse.enhancePrompt.endpoint, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + prompt: text, + context: contextMessages, + model: getModelId(configToUse) || "unknown", + }), + }) + + if (!response.ok) { + throw new Error(`External server returned ${response.status}: ${response.statusText}`) + } + + const result = await response.json() + + if (result.enhancedPrompt) { + return { + success: true, + enhancedText: result.enhancedPrompt, + } + } else { + throw new Error("External server response missing 'enhancedPrompt' field") + } + } catch (err) { + console.error("Failed to enhance prompt via external server:", err) + // Fallback to default logic + } + } + + // Default internal enhancement logic // Prepare the prompt to enhance let promptToEnhance = text