diff --git a/.eslintrc.json b/.eslintrc.json index e967b58a03..f39899d0c8 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -15,6 +15,8 @@ } ], "@typescript-eslint/semi": "off", + "no-unused-vars": "off", + "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }], "eqeqeq": "warn", "no-throw-literal": "warn", "semi": "off" diff --git a/e2e/src/suite/index.ts b/e2e/src/suite/index.ts index d371a0f4c8..1a3e265662 100644 --- a/e2e/src/suite/index.ts +++ b/e2e/src/suite/index.ts @@ -24,15 +24,6 @@ export async function run() { apiProvider: "openrouter" as const, openRouterApiKey: process.env.OPENROUTER_API_KEY!, openRouterModelId: "google/gemini-2.0-flash-001", - openRouterModelInfo: { - maxTokens: 8192, - contextWindow: 1000000, - supportsImages: true, - supportsPromptCache: false, - inputPrice: 0.1, - outputPrice: 0.4, - thinking: false, - }, }) await vscode.commands.executeCommand("roo-cline.SidebarProvider.focus") diff --git a/e2e/src/suite/utils.ts b/e2e/src/suite/utils.ts index 3437c74e55..784d299820 100644 --- a/e2e/src/suite/utils.ts +++ b/e2e/src/suite/utils.ts @@ -1,5 +1,3 @@ -import * as vscode from "vscode" - import type { RooCodeAPI } from "../../../src/exports/roo-code" type WaitForOptions = { diff --git a/evals/apps/web/src/app/runs/new/new-run.tsx b/evals/apps/web/src/app/runs/new/new-run.tsx index 38759eaa87..71b7422ff3 100644 --- a/evals/apps/web/src/app/runs/new/new-run.tsx +++ b/evals/apps/web/src/app/runs/new/new-run.tsx @@ -94,8 +94,7 @@ export function NewRun() { } const openRouterModelId = openRouterModel.id - const openRouterModelInfo = openRouterModel.modelInfo - values.settings = { ...(values.settings || {}), openRouterModelId, openRouterModelInfo } + values.settings = { ...(values.settings || {}), openRouterModelId } } const { id } = await createRun(values) diff --git a/evals/packages/types/src/roo-code.ts b/evals/packages/types/src/roo-code.ts index 2c467f6a9b..3984d54882 100644 --- a/evals/packages/types/src/roo-code.ts +++ b/evals/packages/types/src/roo-code.ts @@ -304,12 +304,10 @@ export const providerSettingsSchema = z.object({ anthropicUseAuthToken: z.boolean().optional(), // Glama glamaModelId: z.string().optional(), - glamaModelInfo: modelInfoSchema.optional(), glamaApiKey: z.string().optional(), // OpenRouter openRouterApiKey: z.string().optional(), openRouterModelId: z.string().optional(), - openRouterModelInfo: modelInfoSchema.optional(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), openRouterUseMiddleOutTransform: z.boolean().optional(), @@ -371,11 +369,9 @@ export const providerSettingsSchema = z.object({ // Unbound unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), - unboundModelInfo: modelInfoSchema.optional(), // Requesty requestyApiKey: z.string().optional(), requestyModelId: z.string().optional(), - requestyModelInfo: modelInfoSchema.optional(), // Claude 3.7 Sonnet Thinking modelMaxTokens: z.number().optional(), // Currently only used by Anthropic hybrid thinking models. modelMaxThinkingTokens: z.number().optional(), // Currently only used by Anthropic hybrid thinking models. @@ -401,12 +397,10 @@ const providerSettingsRecord: ProviderSettingsRecord = { anthropicUseAuthToken: undefined, // Glama glamaModelId: undefined, - glamaModelInfo: undefined, glamaApiKey: undefined, // OpenRouter openRouterApiKey: undefined, openRouterModelId: undefined, - openRouterModelInfo: undefined, openRouterBaseUrl: undefined, openRouterSpecificProvider: undefined, openRouterUseMiddleOutTransform: undefined, @@ -460,11 +454,9 @@ const providerSettingsRecord: ProviderSettingsRecord = { // Unbound unboundApiKey: undefined, unboundModelId: undefined, - unboundModelInfo: undefined, // Requesty requestyApiKey: undefined, requestyModelId: undefined, - requestyModelInfo: undefined, // Claude 3.7 Sonnet Thinking modelMaxTokens: undefined, modelMaxThinkingTokens: undefined, diff --git a/src/__mocks__/McpHub.ts b/src/__mocks__/McpHub.ts index 7aef91b07b..108d6a6ca9 100644 --- a/src/__mocks__/McpHub.ts +++ b/src/__mocks__/McpHub.ts @@ -7,11 +7,11 @@ export class McpHub { this.callTool = jest.fn() } - async toggleToolAlwaysAllow(serverName: string, toolName: string, shouldAllow: boolean): Promise { + async toggleToolAlwaysAllow(_serverName: string, _toolName: string, _shouldAllow: boolean): Promise { return Promise.resolve() } - async callTool(serverName: string, toolName: string, toolArguments?: Record): Promise { + async callTool(_serverName: string, _toolName: string, _toolArguments?: Record): Promise { return Promise.resolve({ result: "success" }) } } diff --git a/src/__tests__/migrateSettings.test.ts b/src/__tests__/migrateSettings.test.ts index 107f310639..9bea4aa9b9 100644 --- a/src/__tests__/migrateSettings.test.ts +++ b/src/__tests__/migrateSettings.test.ts @@ -10,7 +10,6 @@ jest.mock("vscode") jest.mock("fs/promises") jest.mock("fs") jest.mock("../utils/fs") -// We're testing the real migrateSettings function describe("Settings Migration", () => { let mockContext: vscode.ExtensionContext @@ -52,8 +51,6 @@ describe("Settings Migration", () => { }) it("should migrate custom modes file if old file exists and new file doesn't", async () => { - const mockCustomModesContent = '{"customModes":[{"slug":"test-mode"}]}' as string - // Mock file existence checks ;(fileExistsAtPath as jest.Mock).mockImplementation(async (path: string) => { if (path === mockSettingsDir) return true @@ -69,8 +66,6 @@ describe("Settings Migration", () => { }) it("should migrate MCP settings file if old file exists and new file doesn't", async () => { - const mockMcpSettingsContent = '{"mcpServers":{"test-server":{}}}' as string - // Mock file existence checks ;(fileExistsAtPath as jest.Mock).mockImplementation(async (path: string) => { if (path === mockSettingsDir) return true diff --git a/src/activate/registerCodeActions.ts b/src/activate/registerCodeActions.ts index 88e8e218f4..31f474442d 100644 --- a/src/activate/registerCodeActions.ts +++ b/src/activate/registerCodeActions.ts @@ -3,7 +3,6 @@ import * as vscode from "vscode" import { ACTION_NAMES, COMMAND_IDS } from "../core/CodeActionProvider" import { EditorUtils } from "../core/EditorUtils" import { ClineProvider } from "../core/webview/ClineProvider" -import { telemetryService } from "../services/telemetry/TelemetryService" export const registerCodeActions = (context: vscode.ExtensionContext) => { registerCodeActionPair( diff --git a/src/activate/registerCommands.ts b/src/activate/registerCommands.ts index 1883083b6e..a2ce707dab 100644 --- a/src/activate/registerCommands.ts +++ b/src/activate/registerCommands.ts @@ -2,6 +2,10 @@ import * as vscode from "vscode" import delay from "delay" import { ClineProvider } from "../core/webview/ClineProvider" +import { ContextProxy } from "../core/config/ContextProxy" + +import { registerHumanRelayCallback, unregisterHumanRelayCallback, handleHumanRelayResponse } from "./humanRelay" +import { handleNewTask } from "./handleTask" /** * Helper to get the visible ClineProvider instance or log if not found. @@ -15,9 +19,6 @@ export function getVisibleProviderOrLog(outputChannel: vscode.OutputChannel): Cl return visibleProvider } -import { registerHumanRelayCallback, unregisterHumanRelayCallback, handleHumanRelayResponse } from "./humanRelay" -import { handleNewTask } from "./handleTask" - // Store panel references in both modes let sidebarPanel: vscode.WebviewView | undefined = undefined let tabPanel: vscode.WebviewPanel | undefined = undefined @@ -53,7 +54,7 @@ export type RegisterCommandOptions = { } export const registerCommands = (options: RegisterCommandOptions) => { - const { context, outputChannel } = options + const { context } = options for (const [command, callback] of Object.entries(getCommandsMap(options))) { context.subscriptions.push(vscode.commands.registerCommand(command, callback)) @@ -142,7 +143,8 @@ export const openClineInNewTab = async ({ context, outputChannel }: Omit editor.viewColumn || 0)) // Check if there are any visible text editors, otherwise open a new group diff --git a/src/api/providers/__tests__/glama.test.ts b/src/api/providers/__tests__/glama.test.ts index 5e017ccd0a..c7903a8e55 100644 --- a/src/api/providers/__tests__/glama.test.ts +++ b/src/api/providers/__tests__/glama.test.ts @@ -1,7 +1,6 @@ // npx jest src/api/providers/__tests__/glama.test.ts import { Anthropic } from "@anthropic-ai/sdk" -import axios from "axios" import { GlamaHandler } from "../glama" import { ApiHandlerOptions } from "../../../shared/api" @@ -20,31 +19,18 @@ jest.mock("openai", () => { const stream = { [Symbol.asyncIterator]: async function* () { yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], + choices: [{ delta: { content: "Test response" }, index: 0 }], usage: null, } yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, + choices: [{ delta: {}, index: 0 }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, } }, } const result = mockCreate(...args) + if (args[0].stream) { mockWithResponse.mockReturnValue( Promise.resolve({ @@ -59,6 +45,7 @@ jest.mock("openai", () => { ) result.withResponse = mockWithResponse } + return result }, }, @@ -73,10 +60,10 @@ describe("GlamaHandler", () => { beforeEach(() => { mockOptions = { - apiModelId: "anthropic/claude-3-7-sonnet", - glamaModelId: "anthropic/claude-3-7-sonnet", glamaApiKey: "test-api-key", + glamaModelId: "anthropic/claude-3-7-sonnet", } + handler = new GlamaHandler(mockOptions) mockCreate.mockClear() mockWithResponse.mockClear() @@ -102,7 +89,7 @@ describe("GlamaHandler", () => { describe("constructor", () => { it("should initialize with provided options", () => { expect(handler).toBeInstanceOf(GlamaHandler) - expect(handler.getModel().id).toBe(mockOptions.apiModelId) + expect(handler.getModel().id).toBe(mockOptions.glamaModelId) }) }) @@ -116,40 +103,15 @@ describe("GlamaHandler", () => { ] it("should handle streaming responses", async () => { - // Mock axios for token usage request - const mockAxios = jest.spyOn(axios, "get").mockResolvedValueOnce({ - data: { - tokenUsage: { - promptTokens: 10, - completionTokens: 5, - cacheCreationInputTokens: 0, - cacheReadInputTokens: 0, - }, - totalCostUsd: "0.00", - }, - }) - const stream = handler.createMessage(systemPrompt, messages) const chunks: any[] = [] + for await (const chunk of stream) { chunks.push(chunk) } - expect(chunks.length).toBe(2) // Text chunk and usage chunk - expect(chunks[0]).toEqual({ - type: "text", - text: "Test response", - }) - expect(chunks[1]).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 5, - cacheWriteTokens: 0, - cacheReadTokens: 0, - totalCost: 0, - }) - - mockAxios.mockRestore() + expect(chunks.length).toBe(1) + expect(chunks[0]).toEqual({ type: "text", text: "Test response" }) }) it("should handle API errors", async () => { @@ -178,7 +140,7 @@ describe("GlamaHandler", () => { expect(result).toBe("Test response") expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - model: mockOptions.apiModelId, + model: mockOptions.glamaModelId, messages: [{ role: "user", content: "Test prompt" }], temperature: 0, max_tokens: 8192, @@ -204,22 +166,16 @@ describe("GlamaHandler", () => { mockCreate.mockClear() const nonAnthropicOptions = { - apiModelId: "openai/gpt-4", - glamaModelId: "openai/gpt-4", glamaApiKey: "test-key", - glamaModelInfo: { - maxTokens: 4096, - contextWindow: 8192, - supportsImages: true, - supportsPromptCache: false, - }, + glamaModelId: "openai/gpt-4o", } + const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions) await nonAnthropicHandler.completePrompt("Test prompt") expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - model: "openai/gpt-4", + model: "openai/gpt-4o", messages: [{ role: "user", content: "Test prompt" }], temperature: 0, }), @@ -228,13 +184,20 @@ describe("GlamaHandler", () => { }) }) - describe("getModel", () => { - it("should return model info", () => { - const modelInfo = handler.getModel() - expect(modelInfo.id).toBe(mockOptions.apiModelId) + describe("fetchModel", () => { + it("should return model info", async () => { + const modelInfo = await handler.fetchModel() + expect(modelInfo.id).toBe(mockOptions.glamaModelId) expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(200_000) }) + + it("should return default model when invalid model provided", async () => { + const handlerWithInvalidModel = new GlamaHandler({ ...mockOptions, glamaModelId: "invalid/model" }) + const modelInfo = await handlerWithInvalidModel.fetchModel() + expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet") + expect(modelInfo.info).toBeDefined() + }) }) }) diff --git a/src/api/providers/__tests__/openrouter.test.ts b/src/api/providers/__tests__/openrouter.test.ts index 92bc46249a..8cf500930f 100644 --- a/src/api/providers/__tests__/openrouter.test.ts +++ b/src/api/providers/__tests__/openrouter.test.ts @@ -1,35 +1,22 @@ // npx jest src/api/providers/__tests__/openrouter.test.ts -import axios from "axios" import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import { OpenRouterHandler } from "../openrouter" -import { ApiHandlerOptions, ModelInfo } from "../../../shared/api" +import { ApiHandlerOptions } from "../../../shared/api" // Mock dependencies jest.mock("openai") -jest.mock("axios") jest.mock("delay", () => jest.fn(() => Promise.resolve())) -const mockOpenRouterModelInfo: ModelInfo = { - maxTokens: 1000, - contextWindow: 2000, - supportsPromptCache: false, - inputPrice: 0.01, - outputPrice: 0.02, -} - describe("OpenRouterHandler", () => { const mockOptions: ApiHandlerOptions = { openRouterApiKey: "test-key", - openRouterModelId: "test-model", - openRouterModelInfo: mockOpenRouterModelInfo, + openRouterModelId: "anthropic/claude-3.7-sonnet", } - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("initializes with correct options", () => { const handler = new OpenRouterHandler(mockOptions) @@ -45,62 +32,55 @@ describe("OpenRouterHandler", () => { }) }) - describe("getModel", () => { - it("returns correct model info when options are provided", () => { + describe("fetchModel", () => { + it("returns correct model info when options are provided", async () => { const handler = new OpenRouterHandler(mockOptions) - const result = handler.getModel() + const result = await handler.fetchModel() - expect(result).toEqual({ + expect(result).toMatchObject({ id: mockOptions.openRouterModelId, - info: mockOptions.openRouterModelInfo, - maxTokens: 1000, + maxTokens: 8192, thinking: undefined, temperature: 0, reasoningEffort: undefined, topP: undefined, promptCache: { - supported: false, + supported: true, optional: false, }, }) }) - it("returns default model info when options are not provided", () => { + it("returns default model info when options are not provided", async () => { const handler = new OpenRouterHandler({}) - const result = handler.getModel() - + const result = await handler.fetchModel() expect(result.id).toBe("anthropic/claude-3.7-sonnet") expect(result.info.supportsPromptCache).toBe(true) }) - it("honors custom maxTokens for thinking models", () => { + it("honors custom maxTokens for thinking models", async () => { const handler = new OpenRouterHandler({ openRouterApiKey: "test-key", - openRouterModelId: "test-model", - openRouterModelInfo: { - ...mockOpenRouterModelInfo, - maxTokens: 128_000, - thinking: true, - }, + openRouterModelId: "anthropic/claude-3.7-sonnet:thinking", modelMaxTokens: 32_768, modelMaxThinkingTokens: 16_384, }) - const result = handler.getModel() + const result = await handler.fetchModel() expect(result.maxTokens).toBe(32_768) expect(result.thinking).toEqual({ type: "enabled", budget_tokens: 16_384 }) expect(result.temperature).toBe(1.0) }) - it("does not honor custom maxTokens for non-thinking models", () => { + it("does not honor custom maxTokens for non-thinking models", async () => { const handler = new OpenRouterHandler({ ...mockOptions, modelMaxTokens: 32_768, modelMaxThinkingTokens: 16_384, }) - const result = handler.getModel() - expect(result.maxTokens).toBe(1000) + const result = await handler.fetchModel() + expect(result.maxTokens).toBe(8192) expect(result.thinking).toBeUndefined() expect(result.temperature).toBe(0) }) @@ -113,7 +93,7 @@ describe("OpenRouterHandler", () => { const mockStream = { async *[Symbol.asyncIterator]() { yield { - id: "test-id", + id: mockOptions.openRouterModelId, choices: [{ delta: { content: "test response" } }], } yield { @@ -146,16 +126,29 @@ describe("OpenRouterHandler", () => { expect(chunks[0]).toEqual({ type: "text", text: "test response" }) expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20, totalCost: 0.001 }) - // Verify OpenAI client was called with correct parameters + // Verify OpenAI client was called with correct parameters. expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - model: mockOptions.openRouterModelId, - temperature: 0, - messages: expect.arrayContaining([ - { role: "system", content: systemPrompt }, - { role: "user", content: "test message" }, - ]), + max_tokens: 8192, + messages: [ + { + content: [ + { cache_control: { type: "ephemeral" }, text: "test system prompt", type: "text" }, + ], + role: "system", + }, + { + content: [{ cache_control: { type: "ephemeral" }, text: "test message", type: "text" }], + role: "user", + }, + ], + model: "anthropic/claude-3.7-sonnet", stream: true, + stream_options: { include_usage: true }, + temperature: 0, + thinking: undefined, + top_p: undefined, + transforms: ["middle-out"], }), ) }) @@ -178,7 +171,6 @@ describe("OpenRouterHandler", () => { ;(OpenAI as jest.MockedClass).prototype.chat = { completions: { create: mockCreate }, } as any - ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) await handler.createMessage("test", []).next() @@ -188,10 +180,6 @@ describe("OpenRouterHandler", () => { it("adds cache control for supported models", async () => { const handler = new OpenRouterHandler({ ...mockOptions, - openRouterModelInfo: { - ...mockOpenRouterModelInfo, - supportsPromptCache: true, - }, openRouterModelId: "anthropic/claude-3.5-sonnet", }) @@ -208,7 +196,6 @@ describe("OpenRouterHandler", () => { ;(OpenAI as jest.MockedClass).prototype.chat = { completions: { create: mockCreate }, } as any - ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } }) const messages: Anthropic.Messages.MessageParam[] = [ { role: "user", content: "message 1" }, @@ -266,7 +253,7 @@ describe("OpenRouterHandler", () => { expect(mockCreate).toHaveBeenCalledWith({ model: mockOptions.openRouterModelId, - max_tokens: 1000, + max_tokens: 8192, thinking: undefined, temperature: 0, messages: [{ role: "user", content: "test prompt" }], diff --git a/src/api/providers/__tests__/requesty.test.ts b/src/api/providers/__tests__/requesty.test.ts index 2b3da4a7ad..53dda2637e 100644 --- a/src/api/providers/__tests__/requesty.test.ts +++ b/src/api/providers/__tests__/requesty.test.ts @@ -14,22 +14,23 @@ describe("RequestyHandler", () => { let handler: RequestyHandler let mockCreate: jest.Mock + const modelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsComputerUse: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + description: + "Claude 3.7 Sonnet is an advanced large language model with improved reasoning, coding, and problem-solving capabilities. It introduces a hybrid reasoning approach, allowing users to choose between rapid responses and extended, step-by-step processing for complex tasks. The model demonstrates notable improvements in coding, particularly in front-end development and full-stack updates, and excels in agentic workflows, where it can autonomously navigate multi-step processes. Claude 3.7 Sonnet maintains performance parity with its predecessor in standard mode while offering an extended reasoning mode for enhanced accuracy in math, coding, and instruction-following tasks. Read more at the [blog post here](https://www.anthropic.com/news/claude-3-7-sonnet)", + } + const defaultOptions: ApiHandlerOptions = { requestyApiKey: "test-key", requestyModelId: "test-model", - requestyModelInfo: { - maxTokens: 8192, - contextWindow: 200_000, - supportsImages: true, - supportsComputerUse: true, - supportsPromptCache: true, - inputPrice: 3.0, - outputPrice: 15.0, - cacheWritesPrice: 3.75, - cacheReadsPrice: 0.3, - description: - "Claude 3.7 Sonnet is an advanced large language model with improved reasoning, coding, and problem-solving capabilities. It introduces a hybrid reasoning approach, allowing users to choose between rapid responses and extended, step-by-step processing for complex tasks. The model demonstrates notable improvements in coding, particularly in front-end development and full-stack updates, and excels in agentic workflows, where it can autonomously navigate multi-step processes. Claude 3.7 Sonnet maintains performance parity with its predecessor in standard mode while offering an extended reasoning mode for enhanced accuracy in math, coding, and instruction-following tasks. Read more at the [blog post here](https://www.anthropic.com/news/claude-3-7-sonnet)", - }, openAiStreamingEnabled: true, includeMaxTokens: true, // Add this to match the implementation } @@ -185,7 +186,7 @@ describe("RequestyHandler", () => { ], stream: true, stream_options: { include_usage: true }, - max_tokens: defaultOptions.requestyModelInfo?.maxTokens, + max_tokens: modelInfo.maxTokens, }) }) @@ -279,20 +280,17 @@ describe("RequestyHandler", () => { const result = handler.getModel() expect(result).toEqual({ id: defaultOptions.requestyModelId, - info: defaultOptions.requestyModelInfo, + info: modelInfo, }) }) it("should use sane defaults when no model info provided", () => { - handler = new RequestyHandler({ - ...defaultOptions, - requestyModelInfo: undefined, - }) - + handler = new RequestyHandler(defaultOptions) const result = handler.getModel() + expect(result).toEqual({ id: defaultOptions.requestyModelId, - info: defaultOptions.requestyModelInfo, + info: modelInfo, }) }) }) diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts index 8d647ec2db..e174eb7e99 100644 --- a/src/api/providers/__tests__/unbound.test.ts +++ b/src/api/providers/__tests__/unbound.test.ts @@ -1,7 +1,11 @@ -import { UnboundHandler } from "../unbound" -import { ApiHandlerOptions } from "../../../shared/api" +// npx jest src/api/providers/__tests__/unbound.test.ts + import { Anthropic } from "@anthropic-ai/sdk" +import { ApiHandlerOptions } from "../../../shared/api" + +import { UnboundHandler } from "../unbound" + // Mock OpenAI client const mockCreate = jest.fn() const mockWithResponse = jest.fn() @@ -17,12 +21,7 @@ jest.mock("openai", () => { [Symbol.asyncIterator]: async function* () { // First chunk with content yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], + choices: [{ delta: { content: "Test response" }, index: 0 }], } // Second chunk with usage data yield { @@ -48,15 +47,14 @@ jest.mock("openai", () => { } const result = mockCreate(...args) + if (args[0].stream) { mockWithResponse.mockReturnValue( - Promise.resolve({ - data: stream, - response: { headers: new Map() }, - }), + Promise.resolve({ data: stream, response: { headers: new Map() } }), ) result.withResponse = mockWithResponse } + return result }, }, @@ -71,18 +69,10 @@ describe("UnboundHandler", () => { beforeEach(() => { mockOptions = { - apiModelId: "anthropic/claude-3-5-sonnet-20241022", unboundApiKey: "test-api-key", unboundModelId: "anthropic/claude-3-5-sonnet-20241022", - unboundModelInfo: { - description: "Anthropic's Claude 3 Sonnet model", - maxTokens: 8192, - contextWindow: 200000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.02, - }, } + handler = new UnboundHandler(mockOptions) mockCreate.mockClear() mockWithResponse.mockClear() @@ -101,9 +91,9 @@ describe("UnboundHandler", () => { }) describe("constructor", () => { - it("should initialize with provided options", () => { + it("should initialize with provided options", async () => { expect(handler).toBeInstanceOf(UnboundHandler) - expect(handler.getModel().id).toBe(mockOptions.apiModelId) + expect((await handler.fetchModel()).id).toBe(mockOptions.unboundModelId) }) }) @@ -119,6 +109,7 @@ describe("UnboundHandler", () => { it("should handle streaming responses with text and usage data", async () => { const stream = handler.createMessage(systemPrompt, messages) const chunks: Array<{ type: string } & Record> = [] + for await (const chunk of stream) { chunks.push(chunk) } @@ -126,17 +117,10 @@ describe("UnboundHandler", () => { expect(chunks.length).toBe(3) // Verify text chunk - expect(chunks[0]).toEqual({ - type: "text", - text: "Test response", - }) + expect(chunks[0]).toEqual({ type: "text", text: "Test response" }) // Verify regular usage data - expect(chunks[1]).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 5, - }) + expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 5 }) // Verify usage data with cache information expect(chunks[2]).toEqual({ @@ -153,6 +137,7 @@ describe("UnboundHandler", () => { messages: expect.any(Array), stream: true, }), + expect.objectContaining({ headers: { "X-Unbound-Metadata": expect.stringContaining("roo-code"), @@ -173,6 +158,7 @@ describe("UnboundHandler", () => { for await (const chunk of stream) { chunks.push(chunk) } + fail("Expected error to be thrown") } catch (error) { expect(error).toBeInstanceOf(Error) @@ -185,6 +171,7 @@ describe("UnboundHandler", () => { it("should complete prompt successfully", async () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "claude-3-5-sonnet-20241022", @@ -206,9 +193,7 @@ describe("UnboundHandler", () => { }) it("should handle empty response", async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: "" } }], - }) + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "" } }] }) const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) @@ -216,22 +201,14 @@ describe("UnboundHandler", () => { it("should not set max_tokens for non-Anthropic models", async () => { mockCreate.mockClear() - const nonAnthropicOptions = { + const nonAnthropicHandler = new UnboundHandler({ apiModelId: "openai/gpt-4o", unboundApiKey: "test-key", unboundModelId: "openai/gpt-4o", - unboundModelInfo: { - description: "OpenAI's GPT-4", - maxTokens: undefined, - contextWindow: 128000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.03, - }, - } - const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions) + }) await nonAnthropicHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "gpt-4o", @@ -244,27 +221,21 @@ describe("UnboundHandler", () => { }), }), ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens") }) it("should not set temperature for openai/o3-mini", async () => { mockCreate.mockClear() - const openaiOptions = { + const openaiHandler = new UnboundHandler({ apiModelId: "openai/o3-mini", unboundApiKey: "test-key", unboundModelId: "openai/o3-mini", - unboundModelInfo: { - maxTokens: undefined, - contextWindow: 128000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.03, - }, - } - const openaiHandler = new UnboundHandler(openaiOptions) + }) await openaiHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "o3-mini", @@ -276,25 +247,22 @@ describe("UnboundHandler", () => { }), }), ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("temperature") }) }) - describe("getModel", () => { - it("should return model info", () => { - const modelInfo = handler.getModel() - expect(modelInfo.id).toBe(mockOptions.apiModelId) + describe("fetchModel", () => { + it("should return model info", async () => { + const modelInfo = await handler.fetchModel() + expect(modelInfo.id).toBe(mockOptions.unboundModelId) expect(modelInfo.info).toBeDefined() }) - it("should return default model when invalid model provided", () => { - const handlerWithInvalidModel = new UnboundHandler({ - ...mockOptions, - unboundModelId: "invalid/model", - unboundModelInfo: undefined, - }) - const modelInfo = handlerWithInvalidModel.getModel() - expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet-20250219") // Default model + it("should return default model when invalid model provided", async () => { + const handlerWithInvalidModel = new UnboundHandler({ ...mockOptions, unboundModelId: "invalid/model" }) + const modelInfo = await handlerWithInvalidModel.fetchModel() + expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet-20250219") expect(modelInfo.info).toBeDefined() }) }) diff --git a/src/api/providers/fetchers/cache.ts b/src/api/providers/fetchers/cache.ts new file mode 100644 index 0000000000..ab6dcce021 --- /dev/null +++ b/src/api/providers/fetchers/cache.ts @@ -0,0 +1,82 @@ +import * as path from "path" +import fs from "fs/promises" + +import NodeCache from "node-cache" + +import { ContextProxy } from "../../../core/config/ContextProxy" +import { getCacheDirectoryPath } from "../../../shared/storagePathManager" +import { RouterName, ModelRecord } from "../../../shared/api" +import { fileExistsAtPath } from "../../../utils/fs" + +import { getOpenRouterModels } from "./openrouter" +import { getRequestyModels } from "./requesty" +import { getGlamaModels } from "./glama" +import { getUnboundModels } from "./unbound" + +const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) + +async function writeModels(router: RouterName, data: ModelRecord) { + const filename = `${router}_models.json` + const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) + await fs.writeFile(path.join(cacheDir, filename), JSON.stringify(data)) +} + +async function readModels(router: RouterName): Promise { + const filename = `${router}_models.json` + const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) + const filePath = path.join(cacheDir, filename) + const exists = await fileExistsAtPath(filePath) + return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined +} + +/** + * Get models from the cache or fetch them from the provider and cache them. + * There are two caches: + * 1. Memory cache - This is a simple in-memory cache that is used to store models for a short period of time. + * 2. File cache - This is a file-based cache that is used to store models for a longer period of time. + * + * @param router - The router to fetch models from. + * @returns The models from the cache or the fetched models. + */ +export const getModels = async (router: RouterName): Promise => { + let models = memoryCache.get(router) + + if (models) { + // console.log(`[getModels] NodeCache hit for ${router} -> ${Object.keys(models).length}`) + return models + } + + switch (router) { + case "openrouter": + models = await getOpenRouterModels() + break + case "requesty": + models = await getRequestyModels() + break + case "glama": + models = await getGlamaModels() + break + case "unbound": + models = await getUnboundModels() + break + } + + if (Object.keys(models).length > 0) { + // console.log(`[getModels] API fetch for ${router} -> ${Object.keys(models).length}`) + memoryCache.set(router, models) + + try { + await writeModels(router, models) + // console.log(`[getModels] wrote ${router} models to file cache`) + } catch (error) {} + + return models + } + + try { + models = await readModels(router) + // console.log(`[getModels] read ${router} models from file cache`) + } catch (error) {} + + return models ?? {} +} diff --git a/src/api/providers/fetchers/glama.ts b/src/api/providers/fetchers/glama.ts new file mode 100644 index 0000000000..82ceba5233 --- /dev/null +++ b/src/api/providers/fetchers/glama.ts @@ -0,0 +1,42 @@ +import axios from "axios" + +import { ModelInfo } from "../../../shared/api" +import { parseApiPrice } from "../../../utils/cost" + +export async function getGlamaModels(): Promise> { + const models: Record = {} + + try { + const response = await axios.get("https://glama.ai/api/gateway/v1/models") + const rawModels = response.data + + for (const rawModel of rawModels) { + const modelInfo: ModelInfo = { + maxTokens: rawModel.maxTokensOutput, + contextWindow: rawModel.maxTokensInput, + supportsImages: rawModel.capabilities?.includes("input:image"), + supportsComputerUse: rawModel.capabilities?.includes("computer_use"), + supportsPromptCache: rawModel.capabilities?.includes("caching"), + inputPrice: parseApiPrice(rawModel.pricePerToken?.input), + outputPrice: parseApiPrice(rawModel.pricePerToken?.output), + description: undefined, + cacheWritesPrice: parseApiPrice(rawModel.pricePerToken?.cacheWrite), + cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead), + } + + switch (rawModel.id) { + case rawModel.id.startsWith("anthropic/"): + modelInfo.maxTokens = 8192 + break + default: + break + } + + models[rawModel.id] = modelInfo + } + } catch (error) { + console.error(`Error fetching Glama models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + } + + return models +} diff --git a/src/api/providers/fetchers/openrouter.ts b/src/api/providers/fetchers/openrouter.ts index 99dc41faf9..db0ac5a0ca 100644 --- a/src/api/providers/fetchers/openrouter.ts +++ b/src/api/providers/fetchers/openrouter.ts @@ -46,7 +46,7 @@ const openRouterModelsResponseSchema = z.object({ type OpenRouterModelsResponse = z.infer -export async function getOpenRouterModels(options?: ApiHandlerOptions) { +export async function getOpenRouterModels(options?: ApiHandlerOptions): Promise> { const models: Record = {} const baseURL = options?.openRouterBaseUrl || "https://openrouter.ai/api/v1" diff --git a/src/api/providers/fetchers/requesty.ts b/src/api/providers/fetchers/requesty.ts new file mode 100644 index 0000000000..7fe6e41a2b --- /dev/null +++ b/src/api/providers/fetchers/requesty.ts @@ -0,0 +1,41 @@ +import axios from "axios" + +import { ModelInfo } from "../../../shared/api" +import { parseApiPrice } from "../../../utils/cost" + +export async function getRequestyModels(apiKey?: string): Promise> { + const models: Record = {} + + try { + const headers: Record = {} + + if (apiKey) { + headers["Authorization"] = `Bearer ${apiKey}` + } + + const url = "https://router.requesty.ai/v1/models" + const response = await axios.get(url, { headers }) + const rawModels = response.data.data + + for (const rawModel of rawModels) { + const modelInfo: ModelInfo = { + maxTokens: rawModel.max_output_tokens, + contextWindow: rawModel.context_window, + supportsPromptCache: rawModel.supports_caching, + supportsImages: rawModel.supports_vision, + supportsComputerUse: rawModel.supports_computer_use, + inputPrice: parseApiPrice(rawModel.input_price), + outputPrice: parseApiPrice(rawModel.output_price), + description: rawModel.description, + cacheWritesPrice: parseApiPrice(rawModel.caching_price), + cacheReadsPrice: parseApiPrice(rawModel.cached_price), + } + + models[rawModel.id] = modelInfo + } + } catch (error) { + console.error(`Error fetching Requesty models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + } + + return models +} diff --git a/src/api/providers/fetchers/unbound.ts b/src/api/providers/fetchers/unbound.ts new file mode 100644 index 0000000000..73a8c2f897 --- /dev/null +++ b/src/api/providers/fetchers/unbound.ts @@ -0,0 +1,46 @@ +import axios from "axios" + +import { ModelInfo } from "../../../shared/api" + +export async function getUnboundModels(): Promise> { + const models: Record = {} + + try { + const response = await axios.get("https://api.getunbound.ai/models") + + if (response.data) { + const rawModels: Record = response.data + + for (const [modelId, model] of Object.entries(rawModels)) { + const modelInfo: ModelInfo = { + maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined, + contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0, + supportsImages: model?.supportsImages ?? false, + supportsPromptCache: model?.supportsPromptCaching ?? false, + supportsComputerUse: model?.supportsComputerUse ?? false, + inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined, + outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined, + cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined, + cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined, + } + + switch (true) { + case modelId.startsWith("anthropic/"): + // Set max tokens to 8192 for supported Anthropic models + if (modelInfo.maxTokens !== 4096) { + modelInfo.maxTokens = 8192 + } + break + default: + break + } + + models[modelId] = modelInfo + } + } + } catch (error) { + console.error(`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + } + + return models +} diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts index 43b6ebfb7a..72109f6672 100644 --- a/src/api/providers/glama.ts +++ b/src/api/providers/glama.ts @@ -2,119 +2,64 @@ import { Anthropic } from "@anthropic-ai/sdk" import axios from "axios" import OpenAI from "openai" -import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api" -import { parseApiPrice } from "../../utils/cost" -import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiHandlerOptions, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api" import { ApiStream } from "../transform/stream" -import { SingleCompletionHandler } from "../" -import { BaseProvider } from "./base-provider" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { addCacheControlDirectives } from "../transform/caching" +import { SingleCompletionHandler } from "../index" +import { RouterProvider } from "./router-provider" const GLAMA_DEFAULT_TEMPERATURE = 0 -export class GlamaHandler extends BaseProvider implements SingleCompletionHandler { - protected options: ApiHandlerOptions - private client: OpenAI +const DEFAULT_HEADERS = { + "X-Glama-Metadata": JSON.stringify({ labels: [{ key: "app", value: "vscode.rooveterinaryinc.roo-cline" }] }), +} +export class GlamaHandler extends RouterProvider implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super() - this.options = options - const baseURL = "https://glama.ai/api/gateway/openai/v1" - const apiKey = this.options.glamaApiKey ?? "not-provided" - this.client = new OpenAI({ baseURL, apiKey }) - } - - private supportsTemperature(): boolean { - return !this.getModel().id.startsWith("openai/o3-mini") - } - - override getModel(): { id: string; info: ModelInfo } { - const modelId = this.options.glamaModelId - const modelInfo = this.options.glamaModelInfo - - if (modelId && modelInfo) { - return { id: modelId, info: modelInfo } - } - - return { id: glamaDefaultModelId, info: glamaDefaultModelInfo } + super({ + options, + name: "unbound", + baseURL: "https://glama.ai/api/gateway/openai/v1", + apiKey: options.glamaApiKey, + modelId: options.glamaModelId, + defaultModelId: glamaDefaultModelId, + defaultModelInfo: glamaDefaultModelInfo, + }) } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - // Convert Anthropic messages to OpenAI format + const { id: modelId, info } = await this.fetchModel() + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - // this is specifically for claude models (some models may 'support prompt caching' automatically without this) - if (this.getModel().id.startsWith("anthropic/claude-3")) { - openAiMessages[0] = { - role: "system", - content: [ - { - type: "text", - text: systemPrompt, - // @ts-ignore-next-line - cache_control: { type: "ephemeral" }, - }, - ], - } - - // Add cache_control to the last two user messages - // (note: this works because we only ever add one user message at a time, - // but if we added multiple we'd need to mark the user message before the last assistant message) - const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) - lastTwoUserMessages.forEach((msg) => { - if (typeof msg.content === "string") { - msg.content = [{ type: "text", text: msg.content }] - } - if (Array.isArray(msg.content)) { - // NOTE: this is fine since env details will always be added at the end. - // but if it weren't there, and the user added a image_url type message, - // it would pop a text part before it and then move it after to the end. - let lastTextPart = msg.content.filter((part) => part.type === "text").pop() - - if (!lastTextPart) { - lastTextPart = { type: "text", text: "..." } - msg.content.push(lastTextPart) - } - // @ts-ignore-next-line - lastTextPart["cache_control"] = { type: "ephemeral" } - } - }) + if (modelId.startsWith("anthropic/claude-3")) { + addCacheControlDirectives(systemPrompt, openAiMessages) } - // Required by Anthropic - // Other providers default to max tokens allowed. + // Required by Anthropic; other providers default to max tokens allowed. let maxTokens: number | undefined - if (this.getModel().id.startsWith("anthropic/")) { - maxTokens = this.getModel().info.maxTokens ?? undefined + if (modelId.startsWith("anthropic/")) { + maxTokens = info.maxTokens ?? undefined } const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { - model: this.getModel().id, + model: modelId, max_tokens: maxTokens, messages: openAiMessages, stream: true, } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? GLAMA_DEFAULT_TEMPERATURE } const { data: completion, response } = await this.client.chat.completions - .create(requestOptions, { - headers: { - "X-Glama-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "vscode.rooveterinaryinc.roo-cline", - }, - ], - }), - }, - }) + .create(requestOptions, { headers: DEFAULT_HEADERS }) .withResponse() const completionRequestId = response.headers.get("x-completion-request-id") @@ -123,10 +68,7 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle const delta = chunk.choices[0]?.delta if (delta?.content) { - yield { - type: "text", - text: delta.content, - } + yield { type: "text", text: delta.content } } } @@ -140,11 +82,7 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle // before we can fetch information about the token usage and cost. const response = await axios.get( `https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, - { - headers: { - Authorization: `Bearer ${this.options.glamaApiKey}`, - }, - }, + { headers: { Authorization: `Bearer ${this.options.glamaApiKey}` } }, ) const completionRequest = response.data @@ -170,18 +108,20 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle } async completePrompt(prompt: string): Promise { + const { id: modelId, info } = await this.fetchModel() + try { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: this.getModel().id, + model: modelId, messages: [{ role: "user", content: prompt }], } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? GLAMA_DEFAULT_TEMPERATURE } - if (this.getModel().id.startsWith("anthropic/")) { - requestOptions.max_tokens = this.getModel().info.maxTokens + if (modelId.startsWith("anthropic/")) { + requestOptions.max_tokens = info.maxTokens } const response = await this.client.chat.completions.create(requestOptions) @@ -190,45 +130,8 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle if (error instanceof Error) { throw new Error(`Glama completion error: ${error.message}`) } - throw error - } - } -} -export async function getGlamaModels() { - const models: Record = {} - - try { - const response = await axios.get("https://glama.ai/api/gateway/v1/models") - const rawModels = response.data - - for (const rawModel of rawModels) { - const modelInfo: ModelInfo = { - maxTokens: rawModel.maxTokensOutput, - contextWindow: rawModel.maxTokensInput, - supportsImages: rawModel.capabilities?.includes("input:image"), - supportsComputerUse: rawModel.capabilities?.includes("computer_use"), - supportsPromptCache: rawModel.capabilities?.includes("caching"), - inputPrice: parseApiPrice(rawModel.pricePerToken?.input), - outputPrice: parseApiPrice(rawModel.pricePerToken?.output), - description: undefined, - cacheWritesPrice: parseApiPrice(rawModel.pricePerToken?.cacheWrite), - cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead), - } - - switch (rawModel.id) { - case rawModel.id.startsWith("anthropic/"): - modelInfo.maxTokens = 8192 - break - default: - break - } - - models[rawModel.id] = modelInfo + throw error } - } catch (error) { - console.error(`Error fetching Glama models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) } - - return models } diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 4b6a7982ca..71568dfde1 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -74,7 +74,6 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl const enabledR1Format = this.options.openAiR1FormatEnabled ?? false const enabledLegacyFormat = this.options.openAiLegacyFormat ?? false const isAzureAiInference = this._isAzureAiInference(modelUrl) - const urlHost = this._getUrlHost(modelUrl) const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format const ark = modelUrl.includes(".volces.com") diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index a9162495ee..ed19c8496e 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -4,6 +4,7 @@ import OpenAI from "openai" import { ApiHandlerOptions, + ModelRecord, openRouterDefaultModelId, openRouterDefaultModelInfo, PROMPT_CACHING_MODELS, @@ -16,6 +17,7 @@ import { convertToR1Format } from "../transform/r1-format" import { getModelParams, SingleCompletionHandler } from "../index" import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants" import { BaseProvider } from "./base-provider" +import { getModels } from "./fetchers/cache" const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]" @@ -51,6 +53,7 @@ interface CompletionUsage { export class OpenRouterHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions private client: OpenAI + protected models: ModelRecord = {} constructor(options: ApiHandlerOptions) { super() @@ -66,7 +69,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH systemPrompt: string, messages: Anthropic.Messages.MessageParam[], ): AsyncGenerator { - let { id: modelId, maxTokens, thinking, temperature, topP, reasoningEffort, promptCache } = this.getModel() + let { + id: modelId, + maxTokens, + thinking, + temperature, + topP, + reasoningEffort, + promptCache, + } = await this.fetchModel() // Convert Anthropic messages to OpenAI format. let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ @@ -120,8 +131,6 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } // https://openrouter.ai/docs/transforms - let fullResponseText = "" - const completionParams: OpenRouterChatCompletionParams = { model: modelId, max_tokens: maxTokens, @@ -160,7 +169,6 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } if (delta?.content) { - fullResponseText += delta.content yield { type: "text", text: delta.content } } @@ -183,22 +191,27 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } } + public async fetchModel() { + this.models = await getModels("openrouter") + return this.getModel() + } + override getModel() { - const modelId = this.options.openRouterModelId - const modelInfo = this.options.openRouterModelInfo + const id = this.options.openRouterModelId ?? openRouterDefaultModelId + const info = this.models[id] ?? openRouterDefaultModelInfo - let id = modelId ?? openRouterDefaultModelId - const info = modelInfo ?? openRouterDefaultModelInfo - const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning" - const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0 - const topP = isDeepSeekR1 ? 0.95 : undefined + const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || id === "perplexity/sonar-reasoning" return { id, info, // maxTokens, thinking, temperature, reasoningEffort - ...getModelParams({ options: this.options, model: info, defaultTemperature }), - topP, + ...getModelParams({ + options: this.options, + model: info, + defaultTemperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0, + }), + topP: isDeepSeekR1 ? 0.95 : undefined, promptCache: { supported: PROMPT_CACHING_MODELS.has(id), optional: OPTIONAL_PROMPT_CACHING_MODELS.has(id), @@ -207,7 +220,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } async completePrompt(prompt: string) { - let { id: modelId, maxTokens, thinking, temperature } = this.getModel() + let { id: modelId, maxTokens, thinking, temperature } = await this.fetchModel() const completionParams: OpenRouterChatCompletionParams = { model: modelId, diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index 680b6e7179..9fe976bb51 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -1,10 +1,11 @@ -import axios from "axios" +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" -import { ModelInfo, requestyDefaultModelInfo, requestyDefaultModelId } from "../../shared/api" -import { calculateApiCostOpenAI, parseApiPrice } from "../../utils/cost" -import { ApiStreamUsageChunk } from "../transform/stream" +import { ModelInfo, ModelRecord, requestyDefaultModelId, requestyDefaultModelInfo } from "../../shared/api" +import { calculateApiCostOpenAI } from "../../utils/cost" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { OpenAiHandler, OpenAiHandlerOptions } from "./openai" -import OpenAI from "openai" +import { getModels } from "./fetchers/cache" // Requesty usage includes an extra field for Anthropic use cases. // Safely cast the prompt token details section to the appropriate structure. @@ -17,25 +18,30 @@ interface RequestyUsage extends OpenAI.CompletionUsage { } export class RequestyHandler extends OpenAiHandler { + protected models: ModelRecord = {} + constructor(options: OpenAiHandlerOptions) { if (!options.requestyApiKey) { throw new Error("Requesty API key is required. Please provide it in the settings.") } + super({ ...options, openAiApiKey: options.requestyApiKey, openAiModelId: options.requestyModelId ?? requestyDefaultModelId, openAiBaseUrl: "https://router.requesty.ai/v1", - openAiCustomModelInfo: options.requestyModelInfo ?? requestyDefaultModelInfo, }) } + override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + this.models = await getModels("requesty") + yield* super.createMessage(systemPrompt, messages) + } + override getModel(): { id: string; info: ModelInfo } { - const modelId = this.options.requestyModelId ?? requestyDefaultModelId - return { - id: modelId, - info: this.options.requestyModelInfo ?? requestyDefaultModelInfo, - } + const id = this.options.requestyModelId ?? requestyDefaultModelId + const info = this.models[id] ?? requestyDefaultModelInfo + return { id, info } } protected override processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk { @@ -47,6 +53,7 @@ export class RequestyHandler extends OpenAiHandler { const totalCost = modelInfo ? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) : 0 + return { type: "usage", inputTokens: inputTokens, @@ -56,56 +63,9 @@ export class RequestyHandler extends OpenAiHandler { totalCost: totalCost, } } -} - -export async function getRequestyModels(apiKey?: string) { - const models: Record = {} - - try { - const headers: Record = {} - if (apiKey) { - headers["Authorization"] = `Bearer ${apiKey}` - } - - const url = "https://router.requesty.ai/v1/models" - const response = await axios.get(url, { headers }) - const rawModels = response.data.data - for (const rawModel of rawModels) { - // { - // id: "anthropic/claude-3-5-sonnet-20240620", - // object: "model", - // created: 1740552655, - // owned_by: "system", - // input_price: 0.0000028, - // caching_price: 0.00000375, - // cached_price: 3e-7, - // output_price: 0.000015, - // max_output_tokens: 8192, - // context_window: 200000, - // supports_caching: true, - // description: - // "Anthropic's previous most intelligent model. High level of intelligence and capability. Excells in coding.", - // } - - const modelInfo: ModelInfo = { - maxTokens: rawModel.max_output_tokens, - contextWindow: rawModel.context_window, - supportsPromptCache: rawModel.supports_caching, - supportsImages: rawModel.supports_vision, - supportsComputerUse: rawModel.supports_computer_use, - inputPrice: parseApiPrice(rawModel.input_price), - outputPrice: parseApiPrice(rawModel.output_price), - description: rawModel.description, - cacheWritesPrice: parseApiPrice(rawModel.caching_price), - cacheReadsPrice: parseApiPrice(rawModel.cached_price), - } - - models[rawModel.id] = modelInfo - } - } catch (error) { - console.error(`Error fetching Requesty models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + override async completePrompt(prompt: string): Promise { + this.models = await getModels("requesty") + return super.completePrompt(prompt) } - - return models } diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts new file mode 100644 index 0000000000..5b680b1b1d --- /dev/null +++ b/src/api/providers/router-provider.ts @@ -0,0 +1,62 @@ +import OpenAI from "openai" + +import { ApiHandlerOptions, RouterName, ModelRecord, ModelInfo } from "../../shared/api" +import { BaseProvider } from "./base-provider" +import { getModels } from "./fetchers/cache" + +type RouterProviderOptions = { + name: RouterName + baseURL: string + apiKey?: string + modelId?: string + defaultModelId: string + defaultModelInfo: ModelInfo + options: ApiHandlerOptions +} + +export abstract class RouterProvider extends BaseProvider { + protected readonly options: ApiHandlerOptions + protected readonly name: RouterName + protected models: ModelRecord = {} + protected readonly modelId?: string + protected readonly defaultModelId: string + protected readonly defaultModelInfo: ModelInfo + protected readonly client: OpenAI + + constructor({ + options, + name, + baseURL, + apiKey = "not-provided", + modelId, + defaultModelId, + defaultModelInfo, + }: RouterProviderOptions) { + super() + + this.options = options + this.name = name + this.modelId = modelId + this.defaultModelId = defaultModelId + this.defaultModelInfo = defaultModelInfo + + this.client = new OpenAI({ baseURL, apiKey }) + } + + public async fetchModel() { + this.models = await getModels(this.name) + return this.getModel() + } + + override getModel(): { id: string; info: ModelInfo } { + const id = this.modelId ?? this.defaultModelId + + return this.models[id] + ? { id, info: this.models[id] } + : { id: this.defaultModelId, info: this.defaultModelInfo } + } + + protected supportsTemperature(modelId: string): boolean { + return !modelId.startsWith("openai/o3-mini") + } +} diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index 0413c96f29..27c10313a6 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -1,111 +1,67 @@ import { Anthropic } from "@anthropic-ai/sdk" -import axios from "axios" import OpenAI from "openai" -import { ApiHandlerOptions, ModelInfo, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api" -import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiHandlerOptions, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" -import { SingleCompletionHandler } from "../" -import { BaseProvider } from "./base-provider" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { addCacheControlDirectives } from "../transform/caching" +import { SingleCompletionHandler } from "../index" +import { RouterProvider } from "./router-provider" + +const DEFAULT_HEADERS = { + "X-Unbound-Metadata": JSON.stringify({ labels: [{ key: "app", value: "roo-code" }] }), +} interface UnboundUsage extends OpenAI.CompletionUsage { cache_creation_input_tokens?: number cache_read_input_tokens?: number } -export class UnboundHandler extends BaseProvider implements SingleCompletionHandler { - protected options: ApiHandlerOptions - private client: OpenAI - +export class UnboundHandler extends RouterProvider implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super() - this.options = options - const baseURL = "https://api.getunbound.ai/v1" - const apiKey = this.options.unboundApiKey ?? "not-provided" - this.client = new OpenAI({ baseURL, apiKey }) - } - - private supportsTemperature(): boolean { - return !this.getModel().id.startsWith("openai/o3-mini") + super({ + options, + name: "unbound", + baseURL: "https://api.getunbound.ai/v1", + apiKey: options.unboundApiKey, + modelId: options.unboundModelId, + defaultModelId: unboundDefaultModelId, + defaultModelInfo: unboundDefaultModelInfo, + }) } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - // Convert Anthropic messages to OpenAI format + const { id: modelId, info } = await this.fetchModel() + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - // this is specifically for claude models (some models may 'support prompt caching' automatically without this) - if (this.getModel().id.startsWith("anthropic/claude-3")) { - openAiMessages[0] = { - role: "system", - content: [ - { - type: "text", - text: systemPrompt, - // @ts-ignore-next-line - cache_control: { type: "ephemeral" }, - }, - ], - } - - // Add cache_control to the last two user messages - // (note: this works because we only ever add one user message at a time, - // but if we added multiple we'd need to mark the user message before the last assistant message) - const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) - lastTwoUserMessages.forEach((msg) => { - if (typeof msg.content === "string") { - msg.content = [{ type: "text", text: msg.content }] - } - if (Array.isArray(msg.content)) { - // NOTE: this is fine since env details will always be added at the end. - // but if it weren't there, and the user added a image_url type message, - // it would pop a text part before it and then move it after to the end. - let lastTextPart = msg.content.filter((part) => part.type === "text").pop() - - if (!lastTextPart) { - lastTextPart = { type: "text", text: "..." } - msg.content.push(lastTextPart) - } - // @ts-ignore-next-line - lastTextPart["cache_control"] = { type: "ephemeral" } - } - }) + if (modelId.startsWith("anthropic/claude-3")) { + addCacheControlDirectives(systemPrompt, openAiMessages) } - // Required by Anthropic - // Other providers default to max tokens allowed. + // Required by Anthropic; other providers default to max tokens allowed. let maxTokens: number | undefined - if (this.getModel().id.startsWith("anthropic/")) { - maxTokens = this.getModel().info.maxTokens ?? undefined + if (modelId.startsWith("anthropic/")) { + maxTokens = info.maxTokens ?? undefined } const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: this.getModel().id.split("/")[1], + model: modelId.split("/")[1], max_tokens: maxTokens, messages: openAiMessages, stream: true, } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? 0 } - const { data: completion, response } = await this.client.chat.completions - .create(requestOptions, { - headers: { - "X-Unbound-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "roo-code", - }, - ], - }), - }, - }) + const { data: completion } = await this.client.chat.completions + .create(requestOptions, { headers: DEFAULT_HEADERS }) .withResponse() for await (const chunk of completion) { @@ -113,10 +69,7 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand const usage = chunk.usage as UnboundUsage if (delta?.content) { - yield { - type: "text", - text: delta.content, - } + yield { type: "text", text: delta.content } } if (usage) { @@ -126,10 +79,11 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand outputTokens: usage.completion_tokens || 0, } - // Only add cache tokens if they exist + // Only add cache tokens if they exist. if (usage.cache_creation_input_tokens) { usageData.cacheWriteTokens = usage.cache_creation_input_tokens } + if (usage.cache_read_input_tokens) { usageData.cacheReadTokens = usage.cache_read_input_tokens } @@ -139,94 +93,31 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand } } - override getModel(): { id: string; info: ModelInfo } { - const modelId = this.options.unboundModelId - const modelInfo = this.options.unboundModelInfo - if (modelId && modelInfo) { - return { id: modelId, info: modelInfo } - } - return { - id: unboundDefaultModelId, - info: unboundDefaultModelInfo, - } - } - async completePrompt(prompt: string): Promise { + const { id: modelId, info } = await this.fetchModel() + try { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: this.getModel().id.split("/")[1], + model: modelId.split("/")[1], messages: [{ role: "user", content: prompt }], } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? 0 } - if (this.getModel().id.startsWith("anthropic/")) { - requestOptions.max_tokens = this.getModel().info.maxTokens + if (modelId.startsWith("anthropic/")) { + requestOptions.max_tokens = info.maxTokens } - const response = await this.client.chat.completions.create(requestOptions, { - headers: { - "X-Unbound-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "roo-code", - }, - ], - }), - }, - }) + const response = await this.client.chat.completions.create(requestOptions, { headers: DEFAULT_HEADERS }) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { throw new Error(`Unbound completion error: ${error.message}`) } - throw error - } - } -} -export async function getUnboundModels() { - const models: Record = {} - - try { - const response = await axios.get("https://api.getunbound.ai/models") - - if (response.data) { - const rawModels: Record = response.data - - for (const [modelId, model] of Object.entries(rawModels)) { - const modelInfo: ModelInfo = { - maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined, - contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0, - supportsImages: model?.supportsImages ?? false, - supportsPromptCache: model?.supportsPromptCaching ?? false, - supportsComputerUse: model?.supportsComputerUse ?? false, - inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined, - outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined, - cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined, - cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined, - } - - switch (true) { - case modelId.startsWith("anthropic/"): - // Set max tokens to 8192 for supported Anthropic models - if (modelInfo.maxTokens !== 4096) { - modelInfo.maxTokens = 8192 - } - break - default: - break - } - - models[modelId] = modelInfo - } + throw error } - } catch (error) { - console.error(`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) } - - return models } diff --git a/src/api/transform/caching.ts b/src/api/transform/caching.ts new file mode 100644 index 0000000000..0a8ae6bf45 --- /dev/null +++ b/src/api/transform/caching.ts @@ -0,0 +1,36 @@ +import OpenAI from "openai" + +export const addCacheControlDirectives = (systemPrompt: string, messages: OpenAI.Chat.ChatCompletionMessageParam[]) => { + messages[0] = { + role: "system", + content: [ + { + type: "text", + text: systemPrompt, + // @ts-ignore-next-line + cache_control: { type: "ephemeral" }, + }, + ], + } + + messages + .filter((msg) => msg.role === "user") + .slice(-2) + .forEach((msg) => { + if (typeof msg.content === "string") { + msg.content = [{ type: "text", text: msg.content }] + } + + if (Array.isArray(msg.content)) { + let lastTextPart = msg.content.filter((part) => part.type === "text").pop() + + if (!lastTextPart) { + lastTextPart = { type: "text", text: "..." } + msg.content.push(lastTextPart) + } + + // @ts-ignore-next-line + lastTextPart["cache_control"] = { type: "ephemeral" } + } + }) +} diff --git a/src/core/Cline.ts b/src/core/Cline.ts index 380938dc49..304b86d318 100644 --- a/src/core/Cline.ts +++ b/src/core/Cline.ts @@ -33,7 +33,6 @@ import { import { getApiMetrics } from "../shared/getApiMetrics" import { HistoryItem } from "../shared/HistoryItem" import { ClineAskResponse } from "../shared/WebviewMessage" -import { GlobalFileNames } from "../shared/globalFileNames" import { defaultModeSlug, getModeBySlug, getFullModeDetails, isToolAllowedForMode } from "../shared/modes" import { EXPERIMENT_IDS, experiments as Experiments, ExperimentId } from "../shared/experiments" import { formatLanguage } from "../shared/language" @@ -2101,7 +2100,7 @@ export class Cline extends EventEmitter { // Add this terminal's outputs to the details if (terminalOutputs.length > 0) { terminalDetails += `\n## Terminal ${inactiveTerminal.id}` - terminalOutputs.forEach((output, index) => { + terminalOutputs.forEach((output) => { terminalDetails += `\n### New Output\n${output}` }) } diff --git a/src/core/__tests__/Cline.test.ts b/src/core/__tests__/Cline.test.ts index cbad4b1f95..e12078d352 100644 --- a/src/core/__tests__/Cline.test.ts +++ b/src/core/__tests__/Cline.test.ts @@ -11,6 +11,7 @@ import { Cline } from "../Cline" import { ClineProvider } from "../webview/ClineProvider" import { ApiConfiguration, ModelInfo } from "../../shared/api" import { ApiStreamChunk } from "../../api/transform/stream" +import { ContextProxy } from "../config/ContextProxy" // Mock RooIgnoreController jest.mock("../ignore/RooIgnoreController") @@ -225,7 +226,12 @@ describe("Cline", () => { } // Setup mock provider with output channel - mockProvider = new ClineProvider(mockExtensionContext, mockOutputChannel) as jest.Mocked + mockProvider = new ClineProvider( + mockExtensionContext, + mockOutputChannel, + "sidebar", + new ContextProxy(mockExtensionContext), + ) as jest.Mocked // Setup mock API configuration mockApiConfig = { diff --git a/src/core/config/ContextProxy.ts b/src/core/config/ContextProxy.ts index aa40477ad8..dbab107d5b 100644 --- a/src/core/config/ContextProxy.ts +++ b/src/core/config/ContextProxy.ts @@ -256,4 +256,25 @@ export class ContextProxy { await this.initialize() } + + private static _instance: ContextProxy | null = null + + static get instance() { + if (!this._instance) { + throw new Error("ContextProxy not initialized") + } + + return this._instance + } + + static async getInstance(context: vscode.ExtensionContext) { + if (this._instance) { + return this._instance + } + + this._instance = new ContextProxy(context) + await this._instance.initialize() + + return this._instance + } } diff --git a/src/core/sliding-window/__tests__/sliding-window.test.ts b/src/core/sliding-window/__tests__/sliding-window.test.ts index fb7bd9c227..16af2d4630 100644 --- a/src/core/sliding-window/__tests__/sliding-window.test.ts +++ b/src/core/sliding-window/__tests__/sliding-window.test.ts @@ -234,7 +234,6 @@ describe("truncateConversationIfNeeded", () => { it("should not truncate if tokens are below max tokens threshold", async () => { const modelInfo = createModelInfo(100000, 30000) - const maxTokens = 100000 - 30000 // 70000 const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10000 const totalTokens = 70000 - dynamicBuffer - 1 // Just below threshold - buffer @@ -253,7 +252,6 @@ describe("truncateConversationIfNeeded", () => { it("should truncate if tokens are above max tokens threshold", async () => { const modelInfo = createModelInfo(100000, 30000) - const maxTokens = 100000 - 30000 // 70000 const totalTokens = 70001 // Above threshold // Create messages with very small content in the last one to avoid token overflow @@ -393,7 +391,6 @@ describe("truncateConversationIfNeeded", () => { it("should truncate if tokens are within TOKEN_BUFFER_PERCENTAGE of the threshold", async () => { const modelInfo = createModelInfo(100000, 30000) - const maxTokens = 100000 - 30000 // 70000 const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10% of 100000 = 10000 const totalTokens = 70000 - dynamicBuffer + 1 // Just within the dynamic buffer of threshold (70000) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index bf5901b817..b27853e35a 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -15,20 +15,16 @@ import { setPanel } from "../../activate/registerCommands" import { ApiConfiguration, ApiProvider, - ModelInfo, requestyDefaultModelId, - requestyDefaultModelInfo, openRouterDefaultModelId, - openRouterDefaultModelInfo, glamaDefaultModelId, - glamaDefaultModelInfo, } from "../../shared/api" import { findLast } from "../../shared/array" import { supportPrompt } from "../../shared/support-prompt" import { GlobalFileNames } from "../../shared/globalFileNames" import { HistoryItem } from "../../shared/HistoryItem" import { ExtensionMessage } from "../../shared/ExtensionMessage" -import { Mode, PromptComponent, defaultModeSlug, getModeBySlug, getGroupName } from "../../shared/modes" +import { Mode, PromptComponent, defaultModeSlug } from "../../shared/modes" import { experimentDefault } from "../../shared/experiments" import { formatLanguage } from "../../shared/language" import { Terminal, TERMINAL_SHELL_INTEGRATION_TIMEOUT } from "../../integrations/terminal/Terminal" @@ -80,7 +76,6 @@ export class ClineProvider extends EventEmitter implements public isViewLaunched = false public settingsImportedAt?: number public readonly latestAnnouncementId = "apr-23-2025-3-14" // Update for v3.14.0 announcement - public readonly contextProxy: ContextProxy public readonly providerSettingsManager: ProviderSettingsManager public readonly customModesManager: CustomModesManager @@ -88,11 +83,11 @@ export class ClineProvider extends EventEmitter implements readonly context: vscode.ExtensionContext, private readonly outputChannel: vscode.OutputChannel, private readonly renderContext: "sidebar" | "editor" = "sidebar", + public readonly contextProxy: ContextProxy, ) { super() this.log("ClineProvider instantiated") - this.contextProxy = new ContextProxy(context) ClineProvider.activeInstances.add(this) // Register this provider with the telemetry service to enable it to add @@ -939,29 +934,6 @@ export class ClineProvider extends EventEmitter implements return getSettingsDirectoryPath(globalStoragePath) } - private async ensureCacheDirectoryExists() { - const { getCacheDirectoryPath } = await import("../../shared/storagePathManager") - const globalStoragePath = this.contextProxy.globalStorageUri.fsPath - return getCacheDirectoryPath(globalStoragePath) - } - - async writeModelsToCache(filename: string, data: T) { - const cacheDir = await this.ensureCacheDirectoryExists() - await fs.writeFile(path.join(cacheDir, filename), JSON.stringify(data)) - } - - async readModelsFromCache(filename: string): Promise | undefined> { - const filePath = path.join(await this.ensureCacheDirectoryExists(), filename) - const fileExists = await fileExistsAtPath(filePath) - - if (fileExists) { - const fileContents = await fs.readFile(filePath, "utf8") - return JSON.parse(fileContents) - } - - return undefined - } - // OpenRouter async handleOpenRouterCallback(code: string) { @@ -990,7 +962,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "openrouter", openRouterApiKey: apiKey, openRouterModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId, - openRouterModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo, } await this.upsertApiConfiguration(currentApiConfigName, newConfiguration) @@ -1021,7 +992,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "glama", glamaApiKey: apiKey, glamaModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId, - glamaModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo, } await this.upsertApiConfiguration(currentApiConfigName, newConfiguration) @@ -1037,7 +1007,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "requesty", requestyApiKey: code, requestyModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId, - requestyModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo, } await this.upsertApiConfiguration(currentApiConfigName, newConfiguration) diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 5d6067bbf5..4c5a28fa4d 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -9,6 +9,7 @@ import { setSoundEnabled } from "../../../utils/sound" import { setTtsEnabled } from "../../../utils/tts" import { defaultModeSlug } from "../../../shared/modes" import { experimentDefault } from "../../../shared/experiments" +import { ContextProxy } from "../../config/ContextProxy" // Mock setup must come before imports jest.mock("../../prompts/sections/custom-instructions") @@ -307,6 +308,7 @@ describe("ClineProvider", () => { // Mock webview mockPostMessage = jest.fn() + mockWebviewView = { webview: { postMessage: mockPostMessage, @@ -325,7 +327,7 @@ describe("ClineProvider", () => { }), } as unknown as vscode.WebviewView - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) // @ts-ignore - Access private property for testing updateGlobalStateSpy = jest.spyOn(provider.contextProxy, "setValue") @@ -357,6 +359,8 @@ describe("ClineProvider", () => { provider = new ClineProvider( { ...mockContext, extensionMode: vscode.ExtensionMode.Development }, mockOutputChannel, + "sidebar", + new ContextProxy(mockContext), ) ;(axios.get as jest.Mock).mockRejectedValueOnce(new Error("Network error")) @@ -810,7 +814,6 @@ describe("ClineProvider", () => { const modeCustomInstructions = "Code mode instructions" const mockApiConfig = { apiProvider: "openrouter", - openRouterModelInfo: { supportsComputerUse: true }, } jest.spyOn(provider, "getState").mockResolvedValue({ @@ -906,7 +909,7 @@ describe("ClineProvider", () => { } as unknown as vscode.ExtensionContext // Create new provider with updated mock context - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) await provider.resolveWebviewView(mockWebviewView) const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] @@ -1069,16 +1072,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter" as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined, - }, }, mcpEnabled: true, enableMcpServerCreation: false, @@ -1102,16 +1095,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter" as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined, - }, }, mcpEnabled: false, enableMcpServerCreation: false, @@ -1184,7 +1167,6 @@ describe("ClineProvider", () => { apiConfiguration: { apiProvider: "openrouter", apiModelId: "test-model", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: {}, mode: "code", @@ -1241,7 +1223,6 @@ describe("ClineProvider", () => { apiConfiguration: { apiProvider: "openrouter", apiModelId: "test-model", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: {}, mode: "code", @@ -1282,7 +1263,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: { architect: { customInstructions: "Architect mode instructions" }, @@ -1973,7 +1953,7 @@ describe("Project MCP Settings", () => { onDidChangeVisibility: jest.fn(), } as unknown as vscode.WebviewView - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) }) test("handles openProjectMcpSettings message", async () => { @@ -2068,10 +2048,8 @@ describe.skip("ContextProxy integration", () => { } as unknown as vscode.ExtensionContext mockOutputChannel = { appendLine: jest.fn() } as unknown as vscode.OutputChannel - provider = new ClineProvider(mockContext, mockOutputChannel) - - // @ts-ignore - accessing private property for testing - mockContextProxy = provider.contextProxy + mockContextProxy = new ContextProxy(mockContext) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", mockContextProxy) mockGlobalStateUpdate = mockContext.globalState.update as jest.Mock }) @@ -2131,7 +2109,7 @@ describe("getTelemetryProperties", () => { } as unknown as vscode.ExtensionContext mockOutputChannel = { appendLine: jest.fn() } as unknown as vscode.OutputChannel - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) // Setup Cline instance with mocked getModel method const { Cline } = require("../../Cline") diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a62c1d9410..a7cfb20c6f 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -8,7 +8,6 @@ import { Language, ApiConfigMeta } from "../../schemas" import { changeLanguage, t } from "../../i18n" import { ApiConfiguration } from "../../shared/api" import { supportPrompt } from "../../shared/support-prompt" -import { GlobalFileNames } from "../../shared/globalFileNames" import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage" import { checkExistKey } from "../../shared/checkExistApiConfig" @@ -25,10 +24,6 @@ import { playTts, setTtsEnabled, setTtsSpeed, stopTts } from "../../utils/tts" import { singleCompletionHandler } from "../../utils/single-completion-handler" import { searchCommits } from "../../utils/git" import { exportSettings, importSettings } from "../config/importExport" -import { getOpenRouterModels } from "../../api/providers/fetchers/openrouter" -import { getGlamaModels } from "../../api/providers/glama" -import { getUnboundModels } from "../../api/providers/unbound" -import { getRequestyModels } from "../../api/providers/requesty" import { getOpenAiModels } from "../../api/providers/openai" import { getOllamaModels } from "../../api/providers/ollama" import { getVsCodeLmModels } from "../../api/providers/vscode-lm" @@ -42,6 +37,7 @@ import { SYSTEM_PROMPT } from "../prompts/system" import { buildApiHandler } from "../../api" import { GlobalState } from "../../schemas" import { MultiSearchReplaceDiffStrategy } from "../diff/strategies/multi-search-replace" +import { getModels } from "../../api/providers/fetchers/cache" export const webviewMessageHandler = async (provider: ClineProvider, message: WebviewMessage) => { // Utility functions provided for concise get/update of global state via contextProxy API. @@ -56,104 +52,18 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We await updateGlobalState("customModes", customModes) provider.postStateToWebview() - provider.workspaceTracker?.initializeFilePaths() // don't await + provider.workspaceTracker?.initializeFilePaths() // Don't await. getTheme().then((theme) => provider.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })) - // If MCP Hub is already initialized, update the webview with current server list + // If MCP Hub is already initialized, update the webview with + // current server list. const mcpHub = provider.getMcpHub() + if (mcpHub) { - provider.postMessageToWebview({ - type: "mcpServers", - mcpServers: mcpHub.getAllServers(), - }) + provider.postMessageToWebview({ type: "mcpServers", mcpServers: mcpHub.getAllServers() }) } - // Post last cached models in case the call to endpoint fails. - provider.readModelsFromCache(GlobalFileNames.openRouterModels).then((openRouterModels) => { - if (openRouterModels) { - provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - } - }) - - // GUI relies on model info to be up-to-date to provide - // the most accurate pricing, so we need to fetch the - // latest details on launch. - // We do this for all users since many users switch - // between api providers and if they were to switch back - // to OpenRouter it would be showing outdated model info - // if we hadn't retrieved the latest at this point - // (see normalizeApiConfiguration > openrouter). - const { apiConfiguration: currentApiConfig } = await provider.getState() - - getOpenRouterModels(currentApiConfig).then(async (openRouterModels) => { - if (Object.keys(openRouterModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.openRouterModels, openRouterModels) - await provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - - // Update model info in state (this needs to be - // done here since we don't want to update state - // while settings is open, and we may refresh - // models there). - const { apiConfiguration } = await provider.getState() - - if (apiConfiguration.openRouterModelId) { - await updateGlobalState( - "openRouterModelInfo", - openRouterModels[apiConfiguration.openRouterModelId], - ) - - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.glamaModels).then((glamaModels) => { - if (glamaModels) { - provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - } - }) - - getGlamaModels().then(async (glamaModels) => { - if (Object.keys(glamaModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.glamaModels, glamaModels) - await provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - - const { apiConfiguration } = await provider.getState() - - if (apiConfiguration.glamaModelId) { - await updateGlobalState("glamaModelInfo", glamaModels[apiConfiguration.glamaModelId]) - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.unboundModels).then((unboundModels) => { - if (unboundModels) { - provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - } - }) - - getUnboundModels().then(async (unboundModels) => { - if (Object.keys(unboundModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.unboundModels, unboundModels) - await provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - - const { apiConfiguration } = await provider.getState() - - if (apiConfiguration?.unboundModelId) { - await updateGlobalState("unboundModelInfo", unboundModels[apiConfiguration.unboundModelId]) - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.requestyModels).then((requestyModels) => { - if (requestyModels) { - provider.postMessageToWebview({ type: "requestyModels", requestyModels }) - } - }) - provider.providerSettingsManager .listConfig() .then(async (listApiConfig) => { @@ -371,51 +281,32 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We case "resetState": await provider.resetState() break - case "refreshOpenRouterModels": { - const { apiConfiguration: configForRefresh } = await provider.getState() - const openRouterModels = await getOpenRouterModels(configForRefresh) - - if (Object.keys(openRouterModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.openRouterModels, openRouterModels) - await provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - } - - break - } - case "refreshGlamaModels": - const glamaModels = await getGlamaModels() - - if (Object.keys(glamaModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.glamaModels, glamaModels) - await provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - } - - break - case "refreshUnboundModels": - const unboundModels = await getUnboundModels() - - if (Object.keys(unboundModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.unboundModels, unboundModels) - await provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - } - - break - case "refreshRequestyModels": - const requestyModels = await getRequestyModels(message.values?.apiKey) - - if (Object.keys(requestyModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.requestyModels, requestyModels) - await provider.postMessageToWebview({ type: "requestyModels", requestyModels }) - } - + case "requestRouterModels": + const [openRouterModels, requestyModels, glamaModels, unboundModels] = await Promise.all([ + getModels("openrouter"), + getModels("requesty"), + getModels("glama"), + getModels("unbound"), + ]) + + provider.postMessageToWebview({ + type: "routerModels", + routerModels: { + openrouter: openRouterModels, + requesty: requestyModels, + glama: glamaModels, + unbound: unboundModels, + }, + }) break - case "refreshOpenAiModels": + case "requestOpenAiModels": if (message?.values?.baseUrl && message?.values?.apiKey) { const openAiModels = await getOpenAiModels( message?.values?.baseUrl, message?.values?.apiKey, message?.values?.hostHeader, ) + provider.postMessageToWebview({ type: "openAiModels", openAiModels }) } @@ -1413,5 +1304,6 @@ const generateSystemPrompt = async (provider: ClineProvider, message: WebviewMes language, rooIgnoreInstructions, ) + return systemPrompt } diff --git a/src/exports/api.ts b/src/exports/api.ts index 46eee5dc59..0d70d7dc04 100644 --- a/src/exports/api.ts +++ b/src/exports/api.ts @@ -6,7 +6,7 @@ import * as path from "path" import { getWorkspacePath } from "../utils/path" import { ClineProvider } from "../core/webview/ClineProvider" import { openClineInNewTab } from "../activate/registerCommands" -import { RooCodeSettings, RooCodeEvents, RooCodeEventName, ClineMessage } from "../schemas" +import { RooCodeSettings, RooCodeEvents, RooCodeEventName } from "../schemas" import { IpcOrigin, IpcMessageType, TaskCommandName, TaskEvent } from "../schemas/ipc" import { RooCodeAPI } from "./interface" diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 894b776985..3eb31d3b9b 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -28,69 +28,9 @@ type ProviderSettings = { anthropicBaseUrl?: string | undefined anthropicUseAuthToken?: boolean | undefined glamaModelId?: string | undefined - glamaModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined glamaApiKey?: string | undefined openRouterApiKey?: string | undefined openRouterModelId?: string | undefined - openRouterModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined openRouterBaseUrl?: string | undefined openRouterSpecificProvider?: string | undefined openRouterUseMiddleOutTransform?: boolean | undefined @@ -170,68 +110,8 @@ type ProviderSettings = { deepSeekApiKey?: string | undefined unboundApiKey?: string | undefined unboundModelId?: string | undefined - unboundModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined - requestyModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined xaiApiKey?: string | undefined modelMaxTokens?: number | undefined modelMaxThinkingTokens?: number | undefined diff --git a/src/exports/types.ts b/src/exports/types.ts index 4f394c2974..cda03c83d4 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -29,69 +29,9 @@ type ProviderSettings = { anthropicBaseUrl?: string | undefined anthropicUseAuthToken?: boolean | undefined glamaModelId?: string | undefined - glamaModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined glamaApiKey?: string | undefined openRouterApiKey?: string | undefined openRouterModelId?: string | undefined - openRouterModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined openRouterBaseUrl?: string | undefined openRouterSpecificProvider?: string | undefined openRouterUseMiddleOutTransform?: boolean | undefined @@ -171,68 +111,8 @@ type ProviderSettings = { deepSeekApiKey?: string | undefined unboundApiKey?: string | undefined unboundModelId?: string | undefined - unboundModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined - requestyModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined xaiApiKey?: string | undefined modelMaxTokens?: number | undefined modelMaxThinkingTokens?: number | undefined diff --git a/src/extension.ts b/src/extension.ts index aa834c560e..d895bb0e1b 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -15,6 +15,7 @@ try { import "./utils/path" // Necessary to have access to String.prototype.toPosix. import { initializeI18n } from "./i18n" +import { ContextProxy } from "./core/config/ContextProxy" import { ClineProvider } from "./core/webview/ClineProvider" import { CodeActionProvider } from "./core/CodeActionProvider" import { DIFF_VIEW_URI_SCHEME } from "./integrations/editor/DiffViewProvider" @@ -66,7 +67,8 @@ export async function activate(context: vscode.ExtensionContext) { context.globalState.update("allowedCommands", defaultCommands) } - const provider = new ClineProvider(context, outputChannel, "sidebar") + const contextProxy = await ContextProxy.getInstance(context) + const provider = new ClineProvider(context, outputChannel, "sidebar", contextProxy) telemetryService.setProvider(provider) context.subscriptions.push( diff --git a/src/i18n/setup.ts b/src/i18n/setup.ts index 058f357b46..82cb2bf910 100644 --- a/src/i18n/setup.ts +++ b/src/i18n/setup.ts @@ -6,17 +6,6 @@ const translations: Record> = {} // Determine if running in test environment (jest) const isTestEnv = process.env.NODE_ENV === "test" || process.env.JEST_WORKER_ID !== undefined -// Detect environment - browser vs Node.js -const isBrowser = typeof window !== "undefined" && typeof window.document !== "undefined" - -// Define interface for VSCode extension process -interface VSCodeProcess extends NodeJS.Process { - resourcesPath?: string -} - -// Type cast process to custom interface with resourcesPath -const vscodeProcess = process as VSCodeProcess - // Load translations based on environment if (!isTestEnv) { try { diff --git a/src/schemas/index.ts b/src/schemas/index.ts index 2dc4d3f6a1..a6b2cabcda 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -321,12 +321,10 @@ export const providerSettingsSchema = z.object({ anthropicUseAuthToken: z.boolean().optional(), // Glama glamaModelId: z.string().optional(), - glamaModelInfo: modelInfoSchema.nullish(), glamaApiKey: z.string().optional(), // OpenRouter openRouterApiKey: z.string().optional(), openRouterModelId: z.string().optional(), - openRouterModelInfo: modelInfoSchema.nullish(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), openRouterUseMiddleOutTransform: z.boolean().optional(), @@ -388,11 +386,9 @@ export const providerSettingsSchema = z.object({ // Unbound unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), - unboundModelInfo: modelInfoSchema.nullish(), // Requesty requestyApiKey: z.string().optional(), requestyModelId: z.string().optional(), - requestyModelInfo: modelInfoSchema.nullish(), // X.AI (Grok) xaiApiKey: z.string().optional(), // Claude 3.7 Sonnet Thinking @@ -423,12 +419,10 @@ const providerSettingsRecord: ProviderSettingsRecord = { anthropicUseAuthToken: undefined, // Glama glamaModelId: undefined, - glamaModelInfo: undefined, glamaApiKey: undefined, // OpenRouter openRouterApiKey: undefined, openRouterModelId: undefined, - openRouterModelInfo: undefined, openRouterBaseUrl: undefined, openRouterSpecificProvider: undefined, openRouterUseMiddleOutTransform: undefined, @@ -482,11 +476,9 @@ const providerSettingsRecord: ProviderSettingsRecord = { // Unbound unboundApiKey: undefined, unboundModelId: undefined, - unboundModelInfo: undefined, // Requesty requestyApiKey: undefined, requestyModelId: undefined, - requestyModelInfo: undefined, // Claude 3.7 Sonnet Thinking modelMaxTokens: undefined, modelMaxThinkingTokens: undefined, diff --git a/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts b/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts index 6e42cfae07..84589c5fd2 100644 --- a/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts +++ b/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts @@ -12,6 +12,8 @@ import * as fileSearch from "../../../services/search/file-search" import { RepoPerTaskCheckpointService } from "../RepoPerTaskCheckpointService" +jest.setTimeout(10_000) + const tmpDir = path.join(os.tmpdir(), "CheckpointService") const initWorkspaceRepo = async ({ diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index b942188345..00b2cc5194 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -1,5 +1,6 @@ +import { GitCommit } from "../utils/git" + import { - ModelInfo, GlobalSettings, ApiConfigMeta, ProviderSettings as ApiConfiguration, @@ -13,8 +14,8 @@ import { ClineMessage, } from "../schemas" import { McpServer } from "./mcp" -import { GitCommit } from "../utils/git" import { Mode } from "./modes" +import { RouterModels } from "./api" export type { ApiConfigMeta, ToolProgressStatus } @@ -33,24 +34,20 @@ export interface ExtensionMessage { | "action" | "state" | "selectedImages" - | "ollamaModels" - | "lmStudioModels" | "theme" | "workspaceUpdated" | "invoke" | "partialMessage" - | "openRouterModels" - | "glamaModels" - | "unboundModels" - | "requestyModels" - | "openAiModels" | "mcpServers" | "enhancedPrompt" | "commitSearchResults" | "listApiConfig" + | "routerModels" + | "openAiModels" + | "ollamaModels" + | "lmStudioModels" | "vsCodeLmModels" | "vsCodeLmApiAvailable" - | "requestVsCodeLmModels" | "updatePrompt" | "systemPrompt" | "autoApprovalEnabled" @@ -81,9 +78,6 @@ export interface ExtensionMessage { invoke?: "newChat" | "sendMessage" | "primaryButtonClick" | "secondaryButtonClick" | "setChatBoxMessage" state?: ExtensionState images?: string[] - ollamaModels?: string[] - lmStudioModels?: string[] - vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] filePaths?: string[] openedTabs?: Array<{ label: string @@ -91,11 +85,11 @@ export interface ExtensionMessage { path?: string }> partialMessage?: ClineMessage - openRouterModels?: Record - glamaModels?: Record - unboundModels?: Record - requestyModels?: Record + routerModels?: RouterModels openAiModels?: string[] + ollamaModels?: string[] + lmStudioModels?: string[] + vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] mcpServers?: McpServer[] commits?: GitCommit[] listApiConfig?: ApiConfigMeta[] @@ -106,11 +100,7 @@ export interface ExtensionMessage { values?: Record requestId?: string promptText?: string - results?: Array<{ - path: string - type: "file" | "folder" - label?: string - }> + results?: { path: string; type: "file" | "folder"; label?: string }[] error?: string } diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 6b5c111f7a..96943ed8f9 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -1,5 +1,6 @@ import { z } from "zod" -import { ApiConfiguration, ApiProvider } from "./api" + +import { ApiConfiguration } from "./api" import { Mode, PromptComponent, ModeConfig } from "./modes" export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse" @@ -40,17 +41,15 @@ export interface WebviewMessage { | "importSettings" | "exportSettings" | "resetState" + | "requestRouterModels" + | "requestOpenAiModels" | "requestOllamaModels" | "requestLmStudioModels" + | "requestVsCodeLmModels" | "openImage" | "openFile" | "openMention" | "cancelTask" - | "refreshOpenRouterModels" - | "refreshGlamaModels" - | "refreshUnboundModels" - | "refreshRequestyModels" - | "refreshOpenAiModels" | "alwaysAllowBrowser" | "alwaysAllowMcp" | "alwaysAllowModeSwitch" @@ -94,7 +93,6 @@ export interface WebviewMessage { | "alwaysApproveResubmit" | "requestDelaySeconds" | "setApiConfigPassword" - | "requestVsCodeLmModels" | "mode" | "updatePrompt" | "updateSupportPrompt" diff --git a/src/shared/api.ts b/src/shared/api.ts index b5d38e421b..2559232c11 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1446,3 +1446,13 @@ export const COMPUTER_USE_MODELS = new Set([ "anthropic/claude-3.7-sonnet:beta", "anthropic/claude-3.7-sonnet:thinking", ]) + +const routerNames = ["openrouter", "requesty", "glama", "unbound"] as const + +export type RouterName = (typeof routerNames)[number] + +export const isRouterName = (value: string): value is RouterName => routerNames.includes(value as RouterName) + +export type ModelRecord = Record + +export type RouterModels = Record diff --git a/src/shared/globalFileNames.ts b/src/shared/globalFileNames.ts index 68990dfe95..b82ea9b00e 100644 --- a/src/shared/globalFileNames.ts +++ b/src/shared/globalFileNames.ts @@ -1,11 +1,7 @@ export const GlobalFileNames = { apiConversationHistory: "api_conversation_history.json", uiMessages: "ui_messages.json", - glamaModels: "glama_models.json", - openRouterModels: "openrouter_models.json", - requestyModels: "requesty_models.json", mcpSettings: "mcp_settings.json", - unboundModels: "unbound_models.json", customModes: "custom_modes.json", taskMetadata: "task_metadata.json", } diff --git a/webview-ui/.eslintrc.json b/webview-ui/.eslintrc.json index 4309a7c3ed..8ca877d87c 100644 --- a/webview-ui/.eslintrc.json +++ b/webview-ui/.eslintrc.json @@ -1,4 +1,8 @@ { "extends": "react-app", - "ignorePatterns": ["!.storybook"] + "ignorePatterns": ["!.storybook"], + "rules": { + "no-unused-vars": "off", + "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }] + } } diff --git a/webview-ui/src/__tests__/ContextWindowProgress.test.tsx b/webview-ui/src/__tests__/ContextWindowProgress.test.tsx index bd26767db1..c43ecd5d5e 100644 --- a/webview-ui/src/__tests__/ContextWindowProgress.test.tsx +++ b/webview-ui/src/__tests__/ContextWindowProgress.test.tsx @@ -2,6 +2,8 @@ import { render, screen } from "@testing-library/react" import "@testing-library/jest-dom" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import TaskHeader from "@src/components/chat/TaskHeader" // Mock formatLargeNumber function @@ -17,21 +19,15 @@ jest.mock("@vscode/webview-ui-toolkit/react", () => ({ // Mock ExtensionStateContext since we use useExtensionState jest.mock("@src/context/ExtensionStateContext", () => ({ useExtensionState: jest.fn(() => ({ - apiConfiguration: { - apiProvider: "openai", - // Add other needed properties - }, - currentTaskItem: { - id: "test-id", - number: 1, - size: 1024, - }, + apiConfiguration: { apiProvider: "openai" }, + currentTaskItem: { id: "test-id", number: 1, size: 1024 }, })), })) // Mock highlighting function to avoid JSX parsing issues in tests jest.mock("@src/components/chat/TaskHeader", () => { const originalModule = jest.requireActual("@src/components/chat/TaskHeader") + return { __esModule: true, ...originalModule, @@ -39,19 +35,21 @@ jest.mock("@src/components/chat/TaskHeader", () => { } }) +// Mock useSelectedModel hook +jest.mock("@src/components/ui/hooks/useSelectedModel", () => ({ + useSelectedModel: jest.fn(() => ({ + info: { contextWindow: 4000 }, + })), +})) + describe("ContextWindowProgress", () => { + const queryClient = new QueryClient() + // Helper function to render just the ContextWindowProgress part through TaskHeader const renderComponent = (props: Record) => { // Create a simple mock of the task that avoids importing the actual types - const defaultTask = { - ts: Date.now(), - type: "say" as const, - say: "task" as const, - text: "Test task", - } - const defaultProps = { - task: defaultTask, + task: { ts: Date.now(), type: "say" as const, say: "task" as const, text: "Test task" }, tokensIn: 100, tokensOut: 50, doesModelSupportPromptCache: true, @@ -60,18 +58,17 @@ describe("ContextWindowProgress", () => { onClose: jest.fn(), } - return render() + return render( + + + , + ) } - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) - test("renders correctly with valid inputs", () => { - renderComponent({ - contextTokens: 1000, - contextWindow: 4000, - }) + it("renders correctly with valid inputs", () => { + renderComponent({ contextTokens: 1000, contextWindow: 4000 }) // Check for basic elements // The context-window-label is not part of the ContextWindowProgress component @@ -83,11 +80,8 @@ describe("ContextWindowProgress", () => { expect(screen.getByTestId("context-window-size")).toHaveTextContent(/(4000|128000)/) // contextWindow }) - test("handles zero context window gracefully", () => { - renderComponent({ - contextTokens: 0, - contextWindow: 0, - }) + it("handles zero context window gracefully", () => { + renderComponent({ contextTokens: 0, contextWindow: 0 }) // In the current implementation, the component is still displayed with zero values // rather than being hidden completely @@ -96,11 +90,8 @@ describe("ContextWindowProgress", () => { expect(screen.getByTestId("context-tokens-count")).toHaveTextContent("0") }) - test("handles edge cases with negative values", () => { - renderComponent({ - contextTokens: -100, // Should be treated as 0 - contextWindow: 4000, - }) + it("handles edge cases with negative values", () => { + renderComponent({ contextTokens: -100, contextWindow: 4000 }) // Should show 0 instead of -100 expect(screen.getByTestId("context-tokens-count")).toHaveTextContent("0") @@ -108,14 +99,9 @@ describe("ContextWindowProgress", () => { expect(screen.getByTestId("context-window-size")).toHaveTextContent(/(4000|128000)/) }) - test("calculates percentages correctly", () => { - const contextTokens = 1000 - const contextWindow = 4000 + it("calculates percentages correctly", () => { + renderComponent({ contextTokens: 1000, contextWindow: 4000 }) - renderComponent({ - contextTokens, - contextWindow, - }) // Instead of checking the title attribute, verify the data-test-id // which identifies the element containing info about the percentage of tokens used const tokenUsageDiv = screen.getByTestId("context-tokens-used") diff --git a/webview-ui/src/components/chat/ChatView.tsx b/webview-ui/src/components/chat/ChatView.tsx index 419537b361..25dffe9dcb 100644 --- a/webview-ui/src/components/chat/ChatView.tsx +++ b/webview-ui/src/components/chat/ChatView.tsx @@ -24,7 +24,7 @@ import { getAllModes } from "@roo/shared/modes" import { useExtensionState } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" -import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration" +import { useSelectedModel } from "@/components/ui/hooks/useSelectedModel" import { validateCommand } from "@src/utils/command-validation" import { useAppTranslation } from "@src/i18n/TranslationContext" @@ -40,7 +40,7 @@ import TaskHeader from "./TaskHeader" import AutoApproveMenu from "./AutoApproveMenu" import SystemPromptWarning from "./SystemPromptWarning" -interface ChatViewProps { +export interface ChatViewProps { isHidden: boolean showAnnouncement: boolean hideAnnouncement: () => void @@ -490,16 +490,14 @@ const ChatViewComponent: React.ForwardRefRenderFunction { - return normalizeApiConfiguration(apiConfiguration) - }, [apiConfiguration]) + const { info: model } = useSelectedModel(apiConfiguration) const selectImages = useCallback(() => { vscode.postMessage({ type: "selectImages" }) }, []) const shouldDisableImages = - !selectedModelInfo.supportsImages || textAreaDisabled || selectedImages.length >= MAX_IMAGES_PER_MESSAGE + !model?.supportsImages || textAreaDisabled || selectedImages.length >= MAX_IMAGES_PER_MESSAGE const handleMessage = useCallback( (e: MessageEvent) => { @@ -1216,7 +1214,7 @@ const ChatViewComponent: React.ForwardRefRenderFunction handlePrimaryButtonClick(inputValue, selectedImages)}> + onClick={() => handlePrimaryButtonClick(inputValue, selectedImages)}> {primaryButtonText} )} @@ -1389,7 +1387,7 @@ const ChatViewComponent: React.ForwardRefRenderFunction handleSecondaryButtonClick(inputValue, selectedImages)}> + onClick={() => handleSecondaryButtonClick(inputValue, selectedImages)}> {isStreaming ? t("chat:cancel.title") : secondaryButtonText} )} diff --git a/webview-ui/src/components/chat/ContextWindowProgress.tsx b/webview-ui/src/components/chat/ContextWindowProgress.tsx index 8dc69432f6..a5490d9d4f 100644 --- a/webview-ui/src/components/chat/ContextWindowProgress.tsx +++ b/webview-ui/src/components/chat/ContextWindowProgress.tsx @@ -12,6 +12,7 @@ interface ContextWindowProgressProps { export const ContextWindowProgress = ({ contextWindow, contextTokens, maxTokens }: ContextWindowProgressProps) => { const { t } = useTranslation() + // Use the shared utility function to calculate all token distribution values const tokenDistribution = useMemo( () => calculateTokenDistribution(contextWindow, contextTokens, maxTokens), diff --git a/webview-ui/src/components/chat/TaskHeader.tsx b/webview-ui/src/components/chat/TaskHeader.tsx index 9fe3cf5c9d..0b39f62edc 100644 --- a/webview-ui/src/components/chat/TaskHeader.tsx +++ b/webview-ui/src/components/chat/TaskHeader.tsx @@ -1,4 +1,4 @@ -import { memo, useMemo, useRef, useState } from "react" +import { memo, useRef, useState } from "react" import { useWindowSize } from "react-use" import { useTranslation } from "react-i18next" import { VSCodeBadge } from "@vscode/webview-ui-toolkit/react" @@ -11,7 +11,7 @@ import { formatLargeNumber } from "@src/utils/format" import { cn } from "@src/lib/utils" import { Button } from "@src/components/ui" import { useExtensionState } from "@src/context/ExtensionStateContext" -import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration" +import { useSelectedModel } from "@/components/ui/hooks/useSelectedModel" import Thumbnails from "../common/Thumbnails" @@ -19,7 +19,7 @@ import { TaskActions } from "./TaskActions" import { ContextWindowProgress } from "./ContextWindowProgress" import { Mention } from "./Mention" -interface TaskHeaderProps { +export interface TaskHeaderProps { task: ClineMessage tokensIn: number tokensOut: number @@ -44,12 +44,12 @@ const TaskHeader = ({ }: TaskHeaderProps) => { const { t } = useTranslation() const { apiConfiguration, currentTaskItem } = useExtensionState() - const { selectedModelInfo } = useMemo(() => normalizeApiConfiguration(apiConfiguration), [apiConfiguration]) + const { info: model } = useSelectedModel(apiConfiguration) const [isTaskExpanded, setIsTaskExpanded] = useState(false) const textContainerRef = useRef(null) const textRef = useRef(null) - const contextWindow = selectedModelInfo?.contextWindow || 1 + const contextWindow = model?.contextWindow || 1 const { width: windowWidth } = useWindowSize() @@ -96,7 +96,7 @@ const TaskHeader = ({ {!!totalCost && ${totalCost.toFixed(2)}} @@ -132,7 +132,7 @@ const TaskHeader = ({ )} diff --git a/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx b/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx index f1afd66b26..704c705d39 100644 --- a/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx +++ b/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx @@ -1,9 +1,13 @@ -import React from "react" +// npx jest src/components/chat/__tests__/ChatView.auto-approve.test.tsx + import { render, waitFor } from "@testing-library/react" -import ChatView from "../ChatView" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" +import ChatView, { ChatViewProps } from "../ChatView" + // Mock vscode API jest.mock("@src/utils/vscode", () => ({ vscode: { @@ -85,22 +89,32 @@ const mockPostMessage = (state: any) => { ) } +const queryClient = new QueryClient() + +const defaultProps: ChatViewProps = { + isHidden: false, + showAnnouncement: false, + hideAnnouncement: () => {}, + showHistoryView: () => {}, +} + +const renderChatView = (props: Partial = {}) => { + return render( + + + + + , + ) +} + describe("ChatView - Auto Approval Tests", () => { beforeEach(() => { jest.clearAllMocks() }) it("auto-approves read operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -147,16 +161,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves outside workspace read operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -235,16 +240,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve outside workspace read operations without permission", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -297,16 +293,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve when autoApprovalEnabled is false", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -351,16 +338,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves write operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -409,16 +387,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves outside workspace write operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -474,16 +443,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve outside workspace write operations without permission", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -539,16 +499,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves browser actions when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -595,16 +546,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves mode switch when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -651,16 +593,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve mode switch when disabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -705,16 +638,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve mode switch when auto-approval is disabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ diff --git a/webview-ui/src/components/chat/__tests__/ChatView.test.tsx b/webview-ui/src/components/chat/__tests__/ChatView.test.tsx index adc19ca0bb..13d9433f8d 100644 --- a/webview-ui/src/components/chat/__tests__/ChatView.test.tsx +++ b/webview-ui/src/components/chat/__tests__/ChatView.test.tsx @@ -1,9 +1,14 @@ +// npx jest src/components/chat/__tests__/ChatView.test.tsx + import React from "react" import { render, waitFor, act } from "@testing-library/react" -import ChatView from "../ChatView" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" +import ChatView, { ChatViewProps } from "../ChatView" + // Define minimal types needed for testing interface ClineMessage { type: "say" | "ask" @@ -85,13 +90,6 @@ jest.mock("../ChatTextArea", () => { } }) -jest.mock("../TaskHeader", () => ({ - __esModule: true, - default: function MockTaskHeader({ task }: { task: ClineMessage }) { - return
{JSON.stringify(task)}
- }, -})) - // Mock VSCode components jest.mock("@vscode/webview-ui-toolkit/react", () => ({ VSCodeButton: function MockVSCodeButton({ @@ -151,22 +149,30 @@ const mockPostMessage = (state: Partial) => { ) } +const defaultProps: ChatViewProps = { + isHidden: false, + showAnnouncement: false, + hideAnnouncement: () => {}, + showHistoryView: () => {}, +} + +const queryClient = new QueryClient() + +const renderChatView = (props: Partial = {}) => { + return render( + + + + + , + ) +} + describe("ChatView - Auto Approval Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not auto-approve any actions when autoApprovalEnabled is false", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -240,16 +246,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves browser actions when alwaysAllowBrowser is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -296,16 +293,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves read-only tools when alwaysAllowReadOnly is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -353,16 +341,7 @@ describe("ChatView - Auto Approval Tests", () => { describe("Write Tool Auto-Approval Tests", () => { it("auto-approves write tools when alwaysAllowWrite is enabled and message is a tool request", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -411,16 +390,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve write operations when alwaysAllowWrite is enabled but message is not a tool request", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -466,16 +436,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves allowed commands when alwaysAllowExecute is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -524,16 +485,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve disallowed commands even when alwaysAllowExecute is enabled", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -581,16 +533,7 @@ describe("ChatView - Auto Approval Tests", () => { describe("Command Chaining Tests", () => { it("auto-approves chained commands when all parts are allowed", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test various allowed command chaining scenarios const allowedChainedCommands = [ @@ -656,16 +599,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve chained commands when any part is disallowed", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test various command chaining scenarios with disallowed parts const disallowedChainedCommands = [ @@ -728,16 +662,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("handles complex PowerShell command chains correctly", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test PowerShell specific command chains const powershellCommands = { @@ -849,21 +774,10 @@ describe("ChatView - Auto Approval Tests", () => { }) describe("ChatView - Sound Playing Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not play sound for auto-approved browser actions", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -915,16 +829,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays notification sound for non-auto-approved browser actions", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -978,16 +883,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays celebration sound for completion results", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -1037,16 +933,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays progress_loop sound for api failures", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -1097,9 +984,7 @@ describe("ChatView - Sound Playing Tests", () => { }) describe("ChatView - Focus Grabbing Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not grab focus when follow-up question presented", async () => { const sleep = async (timeout: number) => { @@ -1108,16 +993,7 @@ describe("ChatView - Focus Grabbing Tests", () => { }) } - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ diff --git a/webview-ui/src/components/chat/__tests__/TaskHeader.test.tsx b/webview-ui/src/components/chat/__tests__/TaskHeader.test.tsx index dfd82a260f..625bad1169 100644 --- a/webview-ui/src/components/chat/__tests__/TaskHeader.test.tsx +++ b/webview-ui/src/components/chat/__tests__/TaskHeader.test.tsx @@ -2,9 +2,12 @@ import React from "react" import { render, screen } from "@testing-library/react" -import TaskHeader from "../TaskHeader" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import { ApiConfiguration } from "@roo/shared/api" +import TaskHeader, { TaskHeaderProps } from "../TaskHeader" + // Mock the vscode API jest.mock("@/utils/vscode", () => ({ vscode: { @@ -30,8 +33,8 @@ jest.mock("@src/context/ExtensionStateContext", () => ({ })) describe("TaskHeader", () => { - const defaultProps = { - task: { text: "Test task", images: [] }, + const defaultProps: TaskHeaderProps = { + task: { type: "say", ts: Date.now(), text: "Test task", images: [] }, tokensIn: 100, tokensOut: 50, doesModelSupportPromptCache: true, @@ -40,82 +43,38 @@ describe("TaskHeader", () => { onClose: jest.fn(), } - it("should display cost when totalCost is greater than 0", () => { - render( - , + const queryClient = new QueryClient() + + const renderTaskHeader = (props: Partial = {}) => { + return render( + + + , ) + } + + it("should display cost when totalCost is greater than 0", () => { + renderTaskHeader() expect(screen.getByText("$0.05")).toBeInTheDocument() }) it("should not display cost when totalCost is 0", () => { - render( - , - ) + renderTaskHeader({ totalCost: 0 }) expect(screen.queryByText("$0.0000")).not.toBeInTheDocument() }) it("should not display cost when totalCost is null", () => { - render( - , - ) + renderTaskHeader({ totalCost: null as any }) expect(screen.queryByText(/\$/)).not.toBeInTheDocument() }) it("should not display cost when totalCost is undefined", () => { - render( - , - ) + renderTaskHeader({ totalCost: undefined as any }) expect(screen.queryByText(/\$/)).not.toBeInTheDocument() }) it("should not display cost when totalCost is NaN", () => { - render( - , - ) + renderTaskHeader({ totalCost: NaN }) expect(screen.queryByText(/\$/)).not.toBeInTheDocument() }) }) diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 052bc6e5fe..86e9defc5f 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -13,15 +13,11 @@ import { ModelInfo, azureOpenAiDefaultApiVersion, glamaDefaultModelId, - glamaDefaultModelInfo, mistralDefaultModelId, openAiModelInfoSaneDefaults, openRouterDefaultModelId, - openRouterDefaultModelInfo, unboundDefaultModelId, - unboundDefaultModelInfo, requestyDefaultModelId, - requestyDefaultModelInfo, ApiProvider, } from "@roo/shared/api" import { ExtensionMessage } from "@roo/shared/ExtensionMessage" @@ -29,7 +25,8 @@ import { AWS_REGIONS } from "@roo/shared/aws_regions" import { vscode } from "@src/utils/vscode" import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@src/utils/validate" -import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration" +import { useRouterModels } from "@/components/ui/hooks/useRouterModels" +import { useSelectedModel } from "@/components/ui/hooks/useSelectedModel" import { useOpenRouterModelProviders, OPENROUTER_DEFAULT_PROVIDER_NAME, @@ -52,7 +49,7 @@ import { DiffSettingsControl } from "./DiffSettingsControl" import { TemperatureControl } from "./TemperatureControl" import { RateLimitSecondsControl } from "./RateLimitSecondsControl" -interface ApiOptionsProps { +export interface ApiOptionsProps { uriScheme: string | undefined apiConfiguration: ApiConfiguration setApiConfigurationField: (field: K, value: ApiConfiguration[K]) => void @@ -75,22 +72,6 @@ const ApiOptions = ({ const [lmStudioModels, setLmStudioModels] = useState([]) const [vsCodeLmModels, setVsCodeLmModels] = useState([]) - const [openRouterModels, setOpenRouterModels] = useState>({ - [openRouterDefaultModelId]: openRouterDefaultModelInfo, - }) - - const [glamaModels, setGlamaModels] = useState>({ - [glamaDefaultModelId]: glamaDefaultModelInfo, - }) - - const [unboundModels, setUnboundModels] = useState>({ - [unboundDefaultModelId]: unboundDefaultModelInfo, - }) - - const [requestyModels, setRequestyModels] = useState>({ - [requestyDefaultModelId]: requestyDefaultModelInfo, - }) - const [openAiModels, setOpenAiModels] = useState | null>(null) const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl) @@ -117,10 +98,13 @@ const ApiOptions = ({ [setApiConfigurationField], ) - const { selectedProvider, selectedModelId, selectedModelInfo } = useMemo( - () => normalizeApiConfiguration(apiConfiguration), - [apiConfiguration], - ) + const { + provider: selectedProvider, + id: selectedModelId, + info: selectedModelInfo, + } = useSelectedModel(apiConfiguration) + + const { data: routerModels } = useRouterModels() // Update apiConfiguration.aiModelId whenever selectedModelId changes. useEffect(() => { @@ -133,20 +117,9 @@ const ApiOptions = ({ // stops typing. useDebounce( () => { - if (selectedProvider === "openrouter") { - vscode.postMessage({ type: "refreshOpenRouterModels" }) - } else if (selectedProvider === "glama") { - vscode.postMessage({ type: "refreshGlamaModels" }) - } else if (selectedProvider === "unbound") { - vscode.postMessage({ type: "refreshUnboundModels" }) - } else if (selectedProvider === "requesty") { + if (selectedProvider === "openai") { vscode.postMessage({ - type: "refreshRequestyModels", - values: { apiKey: apiConfiguration?.requestyApiKey }, - }) - } else if (selectedProvider === "openai") { - vscode.postMessage({ - type: "refreshOpenAiModels", + type: "requestOpenAiModels", values: { baseUrl: apiConfiguration?.openAiBaseUrl, apiKey: apiConfiguration?.openAiApiKey, @@ -174,43 +147,23 @@ const ApiOptions = ({ useEffect(() => { const apiValidationResult = - validateApiConfiguration(apiConfiguration) || - validateModelId(apiConfiguration, glamaModels, openRouterModels, unboundModels, requestyModels) - + validateApiConfiguration(apiConfiguration) || validateModelId(apiConfiguration, routerModels) setErrorMessage(apiValidationResult) - }, [apiConfiguration, glamaModels, openRouterModels, setErrorMessage, unboundModels, requestyModels]) + }, [apiConfiguration, routerModels, setErrorMessage]) const { data: openRouterModelProviders } = useOpenRouterModelProviders(apiConfiguration?.openRouterModelId, { enabled: selectedProvider === "openrouter" && !!apiConfiguration?.openRouterModelId && - apiConfiguration.openRouterModelId in openRouterModels, + routerModels?.openrouter && + Object.keys(routerModels.openrouter).length > 1 && + apiConfiguration.openRouterModelId in routerModels.openrouter, }) const onMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data switch (message.type) { - case "openRouterModels": { - const updatedModels = message.openRouterModels ?? {} - setOpenRouterModels({ [openRouterDefaultModelId]: openRouterDefaultModelInfo, ...updatedModels }) - break - } - case "glamaModels": { - const updatedModels = message.glamaModels ?? {} - setGlamaModels({ [glamaDefaultModelId]: glamaDefaultModelInfo, ...updatedModels }) - break - } - case "unboundModels": { - const updatedModels = message.unboundModels ?? {} - setUnboundModels({ [unboundDefaultModelId]: unboundDefaultModelInfo, ...updatedModels }) - break - } - case "requestyModels": { - const updatedModels = message.requestyModels ?? {} - setRequestyModels({ [requestyDefaultModelId]: requestyDefaultModelInfo, ...updatedModels }) - break - } case "openAiModels": { const updatedModels = message.openAiModels ?? [] setOpenAiModels(Object.fromEntries(updatedModels.map((item) => [item, openAiModelInfoSaneDefaults]))) @@ -825,10 +778,8 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} defaultModelId="gpt-4o" - defaultModelInfo={openAiModelInfoSaneDefaults} models={openAiModels} modelIdKey="openAiModelId" - modelInfoKey="openAiCustomModelInfo" serviceName="OpenAI" serviceUrl="https://platform.openai.com" /> @@ -1535,10 +1486,8 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} defaultModelId={openRouterDefaultModelId} - defaultModelInfo={openRouterDefaultModelInfo} - models={openRouterModels} + models={routerModels?.openrouter ?? {}} modelIdKey="openRouterModelId" - modelInfoKey="openRouterModelInfo" serviceName="OpenRouter" serviceUrl="https://openrouter.ai/models" /> @@ -1558,16 +1507,7 @@ const ApiOptions = ({ , + VSCodeRadio: ({ value, checked }: any) => , VSCodeRadioGroup: ({ children }: any) =>
{children}
, VSCodeButton: ({ children }: any) =>
{children}
, })) @@ -54,6 +56,11 @@ jest.mock("@/components/ui", () => ({ {children} ), + Slider: ({ value, onChange }: any) => ( +
+ onChange(parseFloat(e.target.value))} /> +
+ ), })) jest.mock("../TemperatureControl", () => ({ @@ -86,16 +93,6 @@ jest.mock("../RateLimitSecondsControl", () => ({ ), })) -// Mock ThinkingBudget component -jest.mock("../ThinkingBudget", () => ({ - ThinkingBudget: ({ apiConfiguration, setApiConfigurationField, modelInfo, provider }: any) => - modelInfo?.thinking ? ( -
- -
- ) : null, -})) - // Mock DiffSettingsControl for tests jest.mock("../DiffSettingsControl", () => ({ DiffSettingsControl: ({ diffEnabled, fuzzyMatchThreshold, onChange }: any) => ( @@ -123,7 +120,23 @@ jest.mock("../DiffSettingsControl", () => ({ ), })) -const renderApiOptions = (props = {}) => { +jest.mock("@src/components/ui/hooks/useSelectedModel", () => ({ + useSelectedModel: jest.fn((apiConfiguration: ApiConfiguration) => { + if (apiConfiguration.apiModelId?.includes("thinking")) { + return { + provider: apiConfiguration.apiProvider, + info: { thinking: true, contextWindow: 4000, maxTokens: 128000 }, + } + } else { + return { + provider: apiConfiguration.apiProvider, + info: { contextWindow: 4000 }, + } + } + }), +})) + +const renderApiOptions = (props: Partial = {}) => { const queryClient = new QueryClient() render( @@ -192,7 +205,6 @@ describe("ApiOptions", () => { apiConfiguration: { apiProvider: "anthropic", apiModelId: "claude-3-opus-20240229", - modelInfo: { thinking: false }, // Non-thinking model }, }) diff --git a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx index 7c415e1ec4..710ca3e5ef 100644 --- a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx +++ b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx @@ -2,6 +2,9 @@ import { screen, fireEvent, render } from "@testing-library/react" import { act } from "react" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + +import { ModelInfo } from "@roo/schemas" import { ModelPicker } from "../ModelPicker" @@ -21,7 +24,8 @@ Element.prototype.scrollIntoView = jest.fn() describe("ModelPicker", () => { const mockSetApiConfigurationField = jest.fn() - const modelInfo = { + + const modelInfo: ModelInfo = { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, @@ -32,16 +36,16 @@ describe("ModelPicker", () => { cacheWritesPrice: 3.75, cacheReadsPrice: 0.3, } + const mockModels = { model1: { name: "Model 1", description: "Test model 1", ...modelInfo }, model2: { name: "Model 2", description: "Test model 2", ...modelInfo }, } + const defaultProps = { apiConfiguration: {}, defaultModelId: "model1", - defaultModelInfo: modelInfo, modelIdKey: "glamaModelId" as const, - modelInfoKey: "glamaModelInfo" as const, serviceName: "Test Service", serviceUrl: "https://test.service", recommendedModel: "recommended-model", @@ -49,14 +53,22 @@ describe("ModelPicker", () => { setApiConfigurationField: mockSetApiConfigurationField, } + const queryClient = new QueryClient() + + const renderModelPicker = () => { + return render( + + + , + ) + } + beforeEach(() => { jest.clearAllMocks() }) it("calls setApiConfigurationField when a model is selected", async () => { - await act(async () => { - render() - }) + await act(async () => renderModelPicker()) await act(async () => { // Open the popover by clicking the button. @@ -84,13 +96,10 @@ describe("ModelPicker", () => { // Verify the API config was updated. expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, "model2") - expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelInfoKey, mockModels.model2) }) it("allows setting a custom model ID that's not in the predefined list", async () => { - await act(async () => { - render() - }) + await act(async () => renderModelPicker()) await act(async () => { // Open the popover by clicking the button. @@ -125,10 +134,5 @@ describe("ModelPicker", () => { // Verify the API config was updated with the custom model ID expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, customModelId) - // The model info should be set to the default since this is a custom model - expect(mockSetApiConfigurationField).toHaveBeenCalledWith( - defaultProps.modelInfoKey, - defaultProps.defaultModelInfo, - ) }) }) diff --git a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx index 9b2fc37d25..81f3dea1fd 100644 --- a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx +++ b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx @@ -9,11 +9,7 @@ import { ExtensionStateContextProvider } from "@/context/ExtensionStateContext" import SettingsView from "../SettingsView" // Mock vscode API -jest.mock("@src/utils/vscode", () => ({ - vscode: { - postMessage: jest.fn(), - }, -})) +jest.mock("@src/utils/vscode", () => ({ vscode: { postMessage: jest.fn() } })) // Mock all lucide-react icons with a proxy to handle any icon requested jest.mock("lucide-react", () => { @@ -79,10 +75,10 @@ jest.mock("@vscode/webview-ui-toolkit/react", () => ({ /> ), VSCodeLink: ({ children, href }: any) => {children}, - VSCodeRadio: ({ children, value, checked, onChange }: any) => ( + VSCodeRadio: ({ value, checked, onChange }: any) => ( ), - VSCodeRadioGroup: ({ children, value, onChange }: any) =>
{children}
, + VSCodeRadioGroup: ({ children, onChange }: any) =>
{children}
, })) // Mock Slider component diff --git a/webview-ui/src/components/ui/hooks/useRouterModels.ts b/webview-ui/src/components/ui/hooks/useRouterModels.ts new file mode 100644 index 0000000000..56b2954db0 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/useRouterModels.ts @@ -0,0 +1,38 @@ +import { RouterModels } from "@roo/shared/api" + +import { vscode } from "@src/utils/vscode" +import { ExtensionMessage } from "@roo/shared/ExtensionMessage" +import { useQuery } from "@tanstack/react-query" + +const getRouterModels = async () => + new Promise((resolve, reject) => { + const cleanup = () => { + window.removeEventListener("message", handler) + } + + const timeout = setTimeout(() => { + cleanup() + reject(new Error("Router models request timed out")) + }, 10000) + + const handler = (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "routerModels") { + clearTimeout(timeout) + cleanup() + + if (message.routerModels) { + console.log("message.routerModels", message.routerModels) + resolve(message.routerModels) + } else { + reject(new Error("No router models in response")) + } + } + } + + window.addEventListener("message", handler) + vscode.postMessage({ type: "requestRouterModels" }) + }) + +export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts new file mode 100644 index 0000000000..6b5b01b918 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -0,0 +1,124 @@ +import { + ApiConfiguration, + RouterModels, + ModelInfo, + anthropicDefaultModelId, + anthropicModels, + bedrockDefaultModelId, + bedrockModels, + deepSeekDefaultModelId, + deepSeekModels, + geminiDefaultModelId, + geminiModels, + mistralDefaultModelId, + mistralModels, + openAiModelInfoSaneDefaults, + openAiNativeDefaultModelId, + openAiNativeModels, + vertexDefaultModelId, + vertexModels, + xaiDefaultModelId, + xaiModels, + vscodeLlmModels, + vscodeLlmDefaultModelId, + openRouterDefaultModelId, + requestyDefaultModelId, + glamaDefaultModelId, + unboundDefaultModelId, +} from "@roo/shared/api" + +import { useRouterModels } from "./useRouterModels" + +export const useSelectedModel = (apiConfiguration?: ApiConfiguration) => { + const { data: routerModels, isLoading, isError } = useRouterModels() + const provider = apiConfiguration?.apiProvider || "anthropic" + const id = apiConfiguration ? getSelectedModelId({ provider, apiConfiguration }) : anthropicDefaultModelId + const info = routerModels ? getSelectedModelInfo({ provider, id, apiConfiguration, routerModels }) : undefined + return { provider, id, info, isLoading, isError } +} + +function getSelectedModelId({ provider, apiConfiguration }: { provider: string; apiConfiguration: ApiConfiguration }) { + switch (provider) { + case "openrouter": + return apiConfiguration.openRouterModelId ?? openRouterDefaultModelId + case "requesty": + return apiConfiguration.requestyModelId ?? requestyDefaultModelId + case "glama": + return apiConfiguration.glamaModelId ?? glamaDefaultModelId + case "unbound": + return apiConfiguration.unboundModelId ?? unboundDefaultModelId + case "openai": + return apiConfiguration.openAiModelId || "" + case "ollama": + return apiConfiguration.ollamaModelId || "" + case "lmstudio": + return apiConfiguration.lmStudioModelId || "" + case "vscode-lm": + return apiConfiguration?.vsCodeLmModelSelector + ? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}` + : "" + default: + return apiConfiguration.apiModelId ?? anthropicDefaultModelId + } +} + +function getSelectedModelInfo({ + provider, + id, + apiConfiguration, + routerModels, +}: { + provider: string + id: string + apiConfiguration?: ApiConfiguration + routerModels: RouterModels +}): ModelInfo { + switch (provider) { + case "openrouter": + return routerModels.openrouter[id] ?? routerModels.openrouter[openRouterDefaultModelId] + case "requesty": + return routerModels.requesty[id] ?? routerModels.requesty[requestyDefaultModelId] + case "glama": + return routerModels.glama[id] ?? routerModels.glama[glamaDefaultModelId] + case "unbound": + return routerModels.unbound[id] ?? routerModels.unbound[unboundDefaultModelId] + case "xai": + return xaiModels[id as keyof typeof xaiModels] ?? xaiModels[xaiDefaultModelId] + case "bedrock": + // Special case for custom ARN. + if (id === "custom-arn") { + return { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true } + } + + return bedrockModels[id as keyof typeof bedrockModels] ?? bedrockModels[bedrockDefaultModelId] + case "vertex": + return vertexModels[id as keyof typeof vertexModels] ?? vertexModels[vertexDefaultModelId] + case "gemini": + return geminiModels[id as keyof typeof geminiModels] ?? geminiModels[geminiDefaultModelId] + case "deepseek": + return deepSeekModels[id as keyof typeof deepSeekModels] ?? deepSeekModels[deepSeekDefaultModelId] + case "openai-native": + return ( + openAiNativeModels[id as keyof typeof openAiNativeModels] ?? + openAiNativeModels[openAiNativeDefaultModelId] + ) + case "mistral": + return mistralModels[id as keyof typeof mistralModels] ?? mistralModels[mistralDefaultModelId] + case "openai": + return apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults + case "ollama": + return openAiModelInfoSaneDefaults + case "lmstudio": + return openAiModelInfoSaneDefaults + case "vscode-lm": + const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId + + return { + ...openAiModelInfoSaneDefaults, + ...vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels], + supportsImages: false, // VSCode LM API currently doesn't support images. + } + default: + return anthropicModels[id as keyof typeof anthropicModels] ?? anthropicModels[anthropicDefaultModelId] + } +} diff --git a/webview-ui/src/components/welcome/WelcomeView.tsx b/webview-ui/src/components/welcome/WelcomeView.tsx index 07871496a4..dd517f7b06 100644 --- a/webview-ui/src/components/welcome/WelcomeView.tsx +++ b/webview-ui/src/components/welcome/WelcomeView.tsx @@ -17,7 +17,7 @@ const WelcomeView = () => { const [errorMessage, setErrorMessage] = useState(undefined) const handleSubmit = useCallback(() => { - const error = validateApiConfiguration(apiConfiguration) + const error = apiConfiguration ? validateApiConfiguration(apiConfiguration) : undefined if (error) { setErrorMessage(error) diff --git a/webview-ui/src/stories/Chat.stories.tsx b/webview-ui/src/stories/Chat.stories.tsx index d9227bc437..55e5345c2e 100644 --- a/webview-ui/src/stories/Chat.stories.tsx +++ b/webview-ui/src/stories/Chat.stories.tsx @@ -37,7 +37,7 @@ const useStorybookChat = (): ChatHandler => { const [input, setInput] = useState("") const [messages, setMessages] = useState([]) - const append = async (message: Message, options?: { data?: any }) => { + const append = async (message: Message, _options?: { data?: any }) => { const echo: Message = { ...message, role: "assistant", diff --git a/webview-ui/src/utils/normalizeApiConfiguration.ts b/webview-ui/src/utils/normalizeApiConfiguration.ts deleted file mode 100644 index 908636ee99..0000000000 --- a/webview-ui/src/utils/normalizeApiConfiguration.ts +++ /dev/null @@ -1,141 +0,0 @@ -import { - ApiConfiguration, - ModelInfo, - anthropicDefaultModelId, - anthropicModels, - bedrockDefaultModelId, - bedrockModels, - deepSeekDefaultModelId, - deepSeekModels, - geminiDefaultModelId, - geminiModels, - glamaDefaultModelId, - glamaDefaultModelInfo, - mistralDefaultModelId, - mistralModels, - openAiModelInfoSaneDefaults, - openAiNativeDefaultModelId, - openAiNativeModels, - openRouterDefaultModelId, - openRouterDefaultModelInfo, - vertexDefaultModelId, - vertexModels, - unboundDefaultModelId, - unboundDefaultModelInfo, - requestyDefaultModelId, - requestyDefaultModelInfo, - xaiDefaultModelId, - xaiModels, - vscodeLlmModels, - vscodeLlmDefaultModelId, -} from "@roo/shared/api" - -export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) { - const provider = apiConfiguration?.apiProvider || "anthropic" - const modelId = apiConfiguration?.apiModelId - - const getProviderData = (models: Record, defaultId: string) => { - let selectedModelId: string - let selectedModelInfo: ModelInfo - - if (modelId && modelId in models) { - selectedModelId = modelId - selectedModelInfo = models[modelId] - } else { - selectedModelId = defaultId - selectedModelInfo = models[defaultId] - } - - return { selectedProvider: provider, selectedModelId, selectedModelInfo } - } - - switch (provider) { - case "anthropic": - return getProviderData(anthropicModels, anthropicDefaultModelId) - case "xai": - return getProviderData(xaiModels, xaiDefaultModelId) - case "bedrock": - // Special case for custom ARN - if (modelId === "custom-arn") { - return { - selectedProvider: provider, - selectedModelId: "custom-arn", - selectedModelInfo: { - maxTokens: 5000, - contextWindow: 128_000, - supportsPromptCache: false, - supportsImages: true, - }, - } - } - return getProviderData(bedrockModels, bedrockDefaultModelId) - case "vertex": - return getProviderData(vertexModels, vertexDefaultModelId) - case "gemini": - return getProviderData(geminiModels, geminiDefaultModelId) - case "deepseek": - return getProviderData(deepSeekModels, deepSeekDefaultModelId) - case "openai-native": - return getProviderData(openAiNativeModels, openAiNativeDefaultModelId) - case "mistral": - return getProviderData(mistralModels, mistralDefaultModelId) - case "openrouter": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId, - selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo, - } - case "glama": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId, - selectedModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo, - } - case "unbound": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.unboundModelId || unboundDefaultModelId, - selectedModelInfo: apiConfiguration?.unboundModelInfo || unboundDefaultModelInfo, - } - case "requesty": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId, - selectedModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo, - } - case "openai": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.openAiModelId || "", - selectedModelInfo: apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults, - } - case "ollama": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.ollamaModelId || "", - selectedModelInfo: openAiModelInfoSaneDefaults, - } - case "lmstudio": - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.lmStudioModelId || "", - selectedModelInfo: openAiModelInfoSaneDefaults, - } - case "vscode-lm": - const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId - const modelInfo = { - ...openAiModelInfoSaneDefaults, - ...vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels], - supportsImages: false, // VSCode LM API currently doesn't support images. - } - return { - selectedProvider: provider, - selectedModelId: apiConfiguration?.vsCodeLmModelSelector - ? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}` - : "", - selectedModelInfo: modelInfo, - } - default: - return getProviderData(anthropicModels, anthropicDefaultModelId) - } -} diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 439839a347..5c072edc0d 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -1,11 +1,8 @@ -import { ApiConfiguration, ModelInfo } from "@roo/shared/api" import i18next from "i18next" -export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined { - if (!apiConfiguration) { - return undefined - } +import { ApiConfiguration, isRouterName, RouterModels } from "@roo/shared/api" +export function validateApiConfiguration(apiConfiguration: ApiConfiguration): string | undefined { switch (apiConfiguration.apiProvider) { case "openrouter": if (!apiConfiguration.openRouterApiKey) { @@ -113,92 +110,41 @@ export function validateBedrockArn(arn: string, region?: string) { } // ARN is valid and region matches (or no region was provided to check against) - return { - isValid: true, - arnRegion, - errorMessage: undefined, - } + return { isValid: true, arnRegion, errorMessage: undefined } } -export function validateModelId( - apiConfiguration?: ApiConfiguration, - glamaModels?: Record, - openRouterModels?: Record, - unboundModels?: Record, - requestyModels?: Record, -): string | undefined { - if (!apiConfiguration) { +export function validateModelId(apiConfiguration: ApiConfiguration, routerModels?: RouterModels): string | undefined { + const provider = apiConfiguration.apiProvider ?? "" + + if (!isRouterName(provider)) { return undefined } - switch (apiConfiguration.apiProvider) { - case "openrouter": - const modelId = apiConfiguration.openRouterModelId - - if (!modelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - openRouterModels && - Object.keys(openRouterModels).length > 1 && - !Object.keys(openRouterModels).includes(modelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId }) - } + let modelId: string | undefined + switch (provider) { + case "openrouter": + modelId = apiConfiguration.openRouterModelId break - case "glama": - const glamaModelId = apiConfiguration.glamaModelId - - if (!glamaModelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - glamaModels && - Object.keys(glamaModels).length > 1 && - !Object.keys(glamaModels).includes(glamaModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: glamaModelId }) - } - + modelId = apiConfiguration.glamaModelId break - case "unbound": - const unboundModelId = apiConfiguration.unboundModelId - - if (!unboundModelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - unboundModels && - Object.keys(unboundModels).length > 1 && - !Object.keys(unboundModels).includes(unboundModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: unboundModelId }) - } - + modelId = apiConfiguration.unboundModelId break - case "requesty": - const requestyModelId = apiConfiguration.requestyModelId + modelId = apiConfiguration.requestyModelId + break + } - if (!requestyModelId) { - return i18next.t("settings:validation.modelId") - } + if (!modelId) { + return i18next.t("settings:validation.modelId") + } - if ( - requestyModels && - Object.keys(requestyModels).length > 1 && - !Object.keys(requestyModels).includes(requestyModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: requestyModelId }) - } + const models = routerModels?.[provider] - break + if (models && Object.keys(models).length > 1 && !Object.keys(models).includes(modelId)) { + return i18next.t("settings:validation.modelAvailability", { modelId }) } return undefined