From 5d40802f7aa6ae331da2ccb2b4d536bb0861bcf0 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 00:51:26 -0700 Subject: [PATCH 01/15] Remove ModelInfo objects from settings --- e2e/src/suite/index.ts | 9 - src/activate/registerCommands.ts | 10 +- src/api/providers/__tests__/glama.test.ts | 6 - .../providers/__tests__/openrouter.test.ts | 11 - src/api/providers/__tests__/requesty.test.ts | 40 ++-- src/api/providers/__tests__/unbound.test.ts | 24 --- src/api/providers/fetchers/cache.ts | 84 ++++++++ src/api/providers/fetchers/glama.ts | 42 ++++ src/api/providers/fetchers/index.ts | 4 + src/api/providers/fetchers/openrouter.ts | 2 +- src/api/providers/fetchers/requesty.ts | 41 ++++ src/api/providers/fetchers/unbound.ts | 46 ++++ src/api/providers/glama.ts | 149 ++++--------- src/api/providers/openai.ts | 1 - src/api/providers/openrouter.ts | 37 +++- src/api/providers/requesty.ts | 82 ++------ src/api/providers/router-provider.ts | 47 +++++ src/api/providers/unbound.ts | 171 ++++----------- src/core/__tests__/Cline.test.ts | 8 +- src/core/config/ContextProxy.ts | 21 ++ src/core/webview/ClineProvider.ts | 35 +-- .../webview/__tests__/ClineProvider.test.ts | 42 +--- src/core/webview/webviewMessageHandler.ts | 160 +++----------- src/exports/roo-code.d.ts | 120 ----------- src/exports/types.ts | 120 ----------- src/extension.ts | 4 +- src/schemas/index.ts | 8 - src/shared/ExtensionMessage.ts | 29 +-- src/shared/WebviewMessage.ts | 9 +- src/shared/globalFileNames.ts | 4 - webview-ui/src/components/chat/ChatView.tsx | 10 +- webview-ui/src/components/chat/TaskHeader.tsx | 12 +- .../src/components/settings/ApiOptions.tsx | 109 ++-------- .../src/components/settings/ModelInfoView.tsx | 20 +- .../src/components/settings/ModelPicker.tsx | 20 +- .../settings/__tests__/ModelPicker.test.tsx | 7 +- .../src/utils/normalizeApiConfiguration.ts | 199 ++++++++++-------- webview-ui/src/utils/validate.ts | 29 ++- 38 files changed, 664 insertions(+), 1108 deletions(-) create mode 100644 src/api/providers/fetchers/cache.ts create mode 100644 src/api/providers/fetchers/glama.ts create mode 100644 src/api/providers/fetchers/index.ts create mode 100644 src/api/providers/fetchers/requesty.ts create mode 100644 src/api/providers/fetchers/unbound.ts create mode 100644 src/api/providers/router-provider.ts diff --git a/e2e/src/suite/index.ts b/e2e/src/suite/index.ts index d371a0f4c8..1a3e265662 100644 --- a/e2e/src/suite/index.ts +++ b/e2e/src/suite/index.ts @@ -24,15 +24,6 @@ export async function run() { apiProvider: "openrouter" as const, openRouterApiKey: process.env.OPENROUTER_API_KEY!, openRouterModelId: "google/gemini-2.0-flash-001", - openRouterModelInfo: { - maxTokens: 8192, - contextWindow: 1000000, - supportsImages: true, - supportsPromptCache: false, - inputPrice: 0.1, - outputPrice: 0.4, - thinking: false, - }, }) await vscode.commands.executeCommand("roo-cline.SidebarProvider.focus") diff --git a/src/activate/registerCommands.ts b/src/activate/registerCommands.ts index 1883083b6e..c00cd9cd2f 100644 --- a/src/activate/registerCommands.ts +++ b/src/activate/registerCommands.ts @@ -2,6 +2,10 @@ import * as vscode from "vscode" import delay from "delay" import { ClineProvider } from "../core/webview/ClineProvider" +import { ContextProxy } from "../core/config/ContextProxy" + +import { registerHumanRelayCallback, unregisterHumanRelayCallback, handleHumanRelayResponse } from "./humanRelay" +import { handleNewTask } from "./handleTask" /** * Helper to get the visible ClineProvider instance or log if not found. 
@@ -15,9 +19,6 @@ export function getVisibleProviderOrLog(outputChannel: vscode.OutputChannel): Cl return visibleProvider } -import { registerHumanRelayCallback, unregisterHumanRelayCallback, handleHumanRelayResponse } from "./humanRelay" -import { handleNewTask } from "./handleTask" - // Store panel references in both modes let sidebarPanel: vscode.WebviewView | undefined = undefined let tabPanel: vscode.WebviewPanel | undefined = undefined @@ -142,7 +143,8 @@ export const openClineInNewTab = async ({ context, outputChannel }: Omit editor.viewColumn || 0)) // Check if there are any visible text editors, otherwise open a new group diff --git a/src/api/providers/__tests__/glama.test.ts b/src/api/providers/__tests__/glama.test.ts index 5e017ccd0a..5870d5c28f 100644 --- a/src/api/providers/__tests__/glama.test.ts +++ b/src/api/providers/__tests__/glama.test.ts @@ -207,12 +207,6 @@ describe("GlamaHandler", () => { apiModelId: "openai/gpt-4", glamaModelId: "openai/gpt-4", glamaApiKey: "test-key", - glamaModelInfo: { - maxTokens: 4096, - contextWindow: 8192, - supportsImages: true, - supportsPromptCache: false, - }, } const nonAnthropicHandler = new GlamaHandler(nonAnthropicOptions) diff --git a/src/api/providers/__tests__/openrouter.test.ts b/src/api/providers/__tests__/openrouter.test.ts index 92bc46249a..6677ec8d72 100644 --- a/src/api/providers/__tests__/openrouter.test.ts +++ b/src/api/providers/__tests__/openrouter.test.ts @@ -24,7 +24,6 @@ describe("OpenRouterHandler", () => { const mockOptions: ApiHandlerOptions = { openRouterApiKey: "test-key", openRouterModelId: "test-model", - openRouterModelInfo: mockOpenRouterModelInfo, } beforeEach(() => { @@ -52,7 +51,6 @@ describe("OpenRouterHandler", () => { expect(result).toEqual({ id: mockOptions.openRouterModelId, - info: mockOptions.openRouterModelInfo, maxTokens: 1000, thinking: undefined, temperature: 0, @@ -77,11 +75,6 @@ describe("OpenRouterHandler", () => { const handler = new OpenRouterHandler({ openRouterApiKey: "test-key", openRouterModelId: "test-model", - openRouterModelInfo: { - ...mockOpenRouterModelInfo, - maxTokens: 128_000, - thinking: true, - }, modelMaxTokens: 32_768, modelMaxThinkingTokens: 16_384, }) @@ -188,10 +181,6 @@ describe("OpenRouterHandler", () => { it("adds cache control for supported models", async () => { const handler = new OpenRouterHandler({ ...mockOptions, - openRouterModelInfo: { - ...mockOpenRouterModelInfo, - supportsPromptCache: true, - }, openRouterModelId: "anthropic/claude-3.5-sonnet", }) diff --git a/src/api/providers/__tests__/requesty.test.ts b/src/api/providers/__tests__/requesty.test.ts index 2b3da4a7ad..53dda2637e 100644 --- a/src/api/providers/__tests__/requesty.test.ts +++ b/src/api/providers/__tests__/requesty.test.ts @@ -14,22 +14,23 @@ describe("RequestyHandler", () => { let handler: RequestyHandler let mockCreate: jest.Mock + const modelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsComputerUse: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + description: + "Claude 3.7 Sonnet is an advanced large language model with improved reasoning, coding, and problem-solving capabilities. It introduces a hybrid reasoning approach, allowing users to choose between rapid responses and extended, step-by-step processing for complex tasks. 
The model demonstrates notable improvements in coding, particularly in front-end development and full-stack updates, and excels in agentic workflows, where it can autonomously navigate multi-step processes. Claude 3.7 Sonnet maintains performance parity with its predecessor in standard mode while offering an extended reasoning mode for enhanced accuracy in math, coding, and instruction-following tasks. Read more at the [blog post here](https://www.anthropic.com/news/claude-3-7-sonnet)", + } + const defaultOptions: ApiHandlerOptions = { requestyApiKey: "test-key", requestyModelId: "test-model", - requestyModelInfo: { - maxTokens: 8192, - contextWindow: 200_000, - supportsImages: true, - supportsComputerUse: true, - supportsPromptCache: true, - inputPrice: 3.0, - outputPrice: 15.0, - cacheWritesPrice: 3.75, - cacheReadsPrice: 0.3, - description: - "Claude 3.7 Sonnet is an advanced large language model with improved reasoning, coding, and problem-solving capabilities. It introduces a hybrid reasoning approach, allowing users to choose between rapid responses and extended, step-by-step processing for complex tasks. The model demonstrates notable improvements in coding, particularly in front-end development and full-stack updates, and excels in agentic workflows, where it can autonomously navigate multi-step processes. Claude 3.7 Sonnet maintains performance parity with its predecessor in standard mode while offering an extended reasoning mode for enhanced accuracy in math, coding, and instruction-following tasks. Read more at the [blog post here](https://www.anthropic.com/news/claude-3-7-sonnet)", - }, openAiStreamingEnabled: true, includeMaxTokens: true, // Add this to match the implementation } @@ -185,7 +186,7 @@ describe("RequestyHandler", () => { ], stream: true, stream_options: { include_usage: true }, - max_tokens: defaultOptions.requestyModelInfo?.maxTokens, + max_tokens: modelInfo.maxTokens, }) }) @@ -279,20 +280,17 @@ describe("RequestyHandler", () => { const result = handler.getModel() expect(result).toEqual({ id: defaultOptions.requestyModelId, - info: defaultOptions.requestyModelInfo, + info: modelInfo, }) }) it("should use sane defaults when no model info provided", () => { - handler = new RequestyHandler({ - ...defaultOptions, - requestyModelInfo: undefined, - }) - + handler = new RequestyHandler(defaultOptions) const result = handler.getModel() + expect(result).toEqual({ id: defaultOptions.requestyModelId, - info: defaultOptions.requestyModelInfo, + info: modelInfo, }) }) }) diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts index 5c54c24e8d..06e05448ce 100644 --- a/src/api/providers/__tests__/unbound.test.ts +++ b/src/api/providers/__tests__/unbound.test.ts @@ -74,14 +74,6 @@ describe("UnboundHandler", () => { apiModelId: "anthropic/claude-3-5-sonnet-20241022", unboundApiKey: "test-api-key", unboundModelId: "anthropic/claude-3-5-sonnet-20241022", - unboundModelInfo: { - description: "Anthropic's Claude 3 Sonnet model", - maxTokens: 8192, - contextWindow: 200000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.02, - }, } handler = new UnboundHandler(mockOptions) mockCreate.mockClear() @@ -220,14 +212,6 @@ describe("UnboundHandler", () => { apiModelId: "openai/gpt-4o", unboundApiKey: "test-key", unboundModelId: "openai/gpt-4o", - unboundModelInfo: { - description: "OpenAI's GPT-4", - maxTokens: undefined, - contextWindow: 128000, - supportsPromptCache: true, - inputPrice: 0.01, - outputPrice: 0.03, - }, } 
const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions)
@@ -254,13 +238,6 @@ describe("UnboundHandler", () => {
 				apiModelId: "openai/o3-mini",
 				unboundApiKey: "test-key",
 				unboundModelId: "openai/o3-mini",
-				unboundModelInfo: {
-					maxTokens: undefined,
-					contextWindow: 128000,
-					supportsPromptCache: true,
-					inputPrice: 0.01,
-					outputPrice: 0.03,
-				},
 			}
 
 			const openaiHandler = new UnboundHandler(openaiOptions)
@@ -291,7 +268,6 @@ describe("UnboundHandler", () => {
 			const handlerWithInvalidModel = new UnboundHandler({
 				...mockOptions,
 				unboundModelId: "invalid/model",
-				unboundModelInfo: undefined,
 			})
 			const modelInfo = handlerWithInvalidModel.getModel()
 			expect(modelInfo.id).toBe("anthropic/claude-3-5-sonnet-20241022") // Default model
diff --git a/src/api/providers/fetchers/cache.ts b/src/api/providers/fetchers/cache.ts
new file mode 100644
index 0000000000..890ee91a1a
--- /dev/null
+++ b/src/api/providers/fetchers/cache.ts
@@ -0,0 +1,84 @@
+import * as path from "path"
+import fs from "fs/promises"
+
+import NodeCache from "node-cache"
+
+import { ContextProxy } from "../../../core/config/ContextProxy"
+import { getCacheDirectoryPath } from "../../../shared/storagePathManager"
+import { fileExistsAtPath } from "../../../utils/fs"
+import type { ModelInfo } from "../../../schemas"
+import { getOpenRouterModels } from "./openrouter"
+import { getRequestyModels } from "./requesty"
+import { getGlamaModels } from "./glama"
+import { getUnboundModels } from "./unbound"
+
+export type RouterName = "openrouter" | "requesty" | "glama" | "unbound"
+
+export type ModelRecord = Record<string, ModelInfo>
+
+const memoryCache = new NodeCache({
+	stdTTL: 5 * 60,
+	checkperiod: 5 * 60,
+})
+
+async function writeModels(router: RouterName, data: ModelRecord) {
+	const filename = `${router}_models.json`
+	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
+	await fs.writeFile(path.join(cacheDir, filename), JSON.stringify(data))
+}
+
+async function readModels(router: RouterName): Promise<ModelRecord | undefined> {
+	const filename = `${router}_models.json`
+	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
+	const filePath = path.join(cacheDir, filename)
+	const exists = await fileExistsAtPath(filePath)
+	return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined
+}
+
+/**
+ * Get models from the cache or fetch them from the provider and cache them.
+ * There are two caches:
+ * 1. Memory cache - This is a simple in-memory cache that is used to store models for a short period of time.
+ * 2. File cache - This is a file-based cache that is used to store models for a longer period of time.
+ *
+ * @param router - The router to fetch models from.
+ * @returns The models from the cache or the fetched models.
+ */
+export const getModels = async (router: RouterName): Promise<ModelRecord> => {
+	let models = memoryCache.get(router)
+
+	if (models) {
+		return models
+	}
+
+	switch (router) {
+		case "openrouter":
+			models = await getOpenRouterModels()
+			break
+		case "requesty":
+			models = await getRequestyModels()
+			break
+		case "glama":
+			models = await getGlamaModels()
+			break
+		case "unbound":
+			models = await getUnboundModels()
+			break
+	}
+
+	if (Object.keys(models).length > 0) {
+		memoryCache.set(router, models)
+
+		try {
+			await writeModels(router, models)
+		} catch (error) {}
+
+		return models
+	}
+
+	try {
+		models = await readModels(router)
+	} catch (error) {}
+
+	return models ?? 
{}
+}
diff --git a/src/api/providers/fetchers/glama.ts b/src/api/providers/fetchers/glama.ts
new file mode 100644
index 0000000000..82ceba5233
--- /dev/null
+++ b/src/api/providers/fetchers/glama.ts
@@ -0,0 +1,42 @@
+import axios from "axios"
+
+import { ModelInfo } from "../../../shared/api"
+import { parseApiPrice } from "../../../utils/cost"
+
+export async function getGlamaModels(): Promise<Record<string, ModelInfo>> {
+	const models: Record<string, ModelInfo> = {}
+
+	try {
+		const response = await axios.get("https://glama.ai/api/gateway/v1/models")
+		const rawModels = response.data
+
+		for (const rawModel of rawModels) {
+			const modelInfo: ModelInfo = {
+				maxTokens: rawModel.maxTokensOutput,
+				contextWindow: rawModel.maxTokensInput,
+				supportsImages: rawModel.capabilities?.includes("input:image"),
+				supportsComputerUse: rawModel.capabilities?.includes("computer_use"),
+				supportsPromptCache: rawModel.capabilities?.includes("caching"),
+				inputPrice: parseApiPrice(rawModel.pricePerToken?.input),
+				outputPrice: parseApiPrice(rawModel.pricePerToken?.output),
+				description: undefined,
+				cacheWritesPrice: parseApiPrice(rawModel.pricePerToken?.cacheWrite),
+				cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead),
+			}
+
+			switch (true) {
+				case rawModel.id.startsWith("anthropic/"):
+					modelInfo.maxTokens = 8192
+					break
+				default:
+					break
+			}
+
+			models[rawModel.id] = modelInfo
+		}
+	} catch (error) {
+		console.error(`Error fetching Glama models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
+	}
+
+	return models
+}
diff --git a/src/api/providers/fetchers/index.ts b/src/api/providers/fetchers/index.ts
new file mode 100644
index 0000000000..24644c6542
--- /dev/null
+++ b/src/api/providers/fetchers/index.ts
@@ -0,0 +1,4 @@
+export { getOpenRouterModels } from "./openrouter"
+export { getRequestyModels } from "./requesty"
+export { getGlamaModels } from "./glama"
+export { getUnboundModels } from "./unbound"
diff --git a/src/api/providers/fetchers/openrouter.ts b/src/api/providers/fetchers/openrouter.ts
index 99dc41faf9..db0ac5a0ca 100644
--- a/src/api/providers/fetchers/openrouter.ts
+++ b/src/api/providers/fetchers/openrouter.ts
@@ -46,7 +46,7 @@ const openRouterModelsResponseSchema = z.object({
 
 type OpenRouterModelsResponse = z.infer<typeof openRouterModelsResponseSchema>
 
-export async function getOpenRouterModels(options?: ApiHandlerOptions) {
+export async function getOpenRouterModels(options?: ApiHandlerOptions): Promise<Record<string, ModelInfo>> {
 	const models: Record<string, ModelInfo> = {}
 
 	const baseURL = options?.openRouterBaseUrl || "https://openrouter.ai/api/v1"
diff --git a/src/api/providers/fetchers/requesty.ts b/src/api/providers/fetchers/requesty.ts
new file mode 100644
index 0000000000..7fe6e41a2b
--- /dev/null
+++ b/src/api/providers/fetchers/requesty.ts
@@ -0,0 +1,41 @@
+import axios from "axios"
+
+import { ModelInfo } from "../../../shared/api"
+import { parseApiPrice } from "../../../utils/cost"
+
+export async function getRequestyModels(apiKey?: string): Promise<Record<string, ModelInfo>> {
+	const models: Record<string, ModelInfo> = {}
+
+	try {
+		const headers: Record<string, string> = {}
+
+		if (apiKey) {
+			headers["Authorization"] = `Bearer ${apiKey}`
+		}
+
+		const url = "https://router.requesty.ai/v1/models"
+		const response = await axios.get(url, { headers })
+		const rawModels = response.data.data
+
+		for (const rawModel of rawModels) {
+			const modelInfo: ModelInfo = {
+				maxTokens: rawModel.max_output_tokens,
+				contextWindow: rawModel.context_window,
+				supportsPromptCache: rawModel.supports_caching,
+				supportsImages: rawModel.supports_vision,
+				supportsComputerUse: rawModel.supports_computer_use,
+				inputPrice: 
parseApiPrice(rawModel.input_price),
+				outputPrice: parseApiPrice(rawModel.output_price),
+				description: rawModel.description,
+				cacheWritesPrice: parseApiPrice(rawModel.caching_price),
+				cacheReadsPrice: parseApiPrice(rawModel.cached_price),
+			}
+
+			models[rawModel.id] = modelInfo
+		}
+	} catch (error) {
+		console.error(`Error fetching Requesty models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
+	}
+
+	return models
+}
diff --git a/src/api/providers/fetchers/unbound.ts b/src/api/providers/fetchers/unbound.ts
new file mode 100644
index 0000000000..73a8c2f897
--- /dev/null
+++ b/src/api/providers/fetchers/unbound.ts
@@ -0,0 +1,46 @@
+import axios from "axios"
+
+import { ModelInfo } from "../../../shared/api"
+
+export async function getUnboundModels(): Promise<Record<string, ModelInfo>> {
+	const models: Record<string, ModelInfo> = {}
+
+	try {
+		const response = await axios.get("https://api.getunbound.ai/models")
+
+		if (response.data) {
+			const rawModels: Record<string, any> = response.data
+
+			for (const [modelId, model] of Object.entries(rawModels)) {
+				const modelInfo: ModelInfo = {
+					maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined,
+					contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
+					supportsImages: model?.supportsImages ?? false,
+					supportsPromptCache: model?.supportsPromptCaching ?? false,
+					supportsComputerUse: model?.supportsComputerUse ?? false,
+					inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined,
+					outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined,
+					cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
+					cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined,
+				}
+
+				switch (true) {
+					case modelId.startsWith("anthropic/"):
+						// Set max tokens to 8192 for supported Anthropic models
+						if (modelInfo.maxTokens !== 4096) {
+							modelInfo.maxTokens = 8192
+						}
+						break
+					default:
+						break
+				}
+
+				models[modelId] = modelInfo
+			}
+		}
+	} catch (error) {
+		console.error(`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`)
+	}
+
+	return models
+}
diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts
index 43b6ebfb7a..3e010ed920 100644
--- a/src/api/providers/glama.ts
+++ b/src/api/providers/glama.ts
@@ -2,118 +2,87 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import axios from "axios"
 import OpenAI from "openai"
 
-import { ApiHandlerOptions, ModelInfo, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
-import { parseApiPrice } from "../../utils/cost"
+import { ApiHandlerOptions, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
-import { SingleCompletionHandler } from "../"
-import { BaseProvider } from "./base-provider"
+import { SingleCompletionHandler } from "../index"
+import { RouterProvider } from "./router-provider"
 
 const GLAMA_DEFAULT_TEMPERATURE = 0
 
-export class GlamaHandler extends BaseProvider implements SingleCompletionHandler {
-	protected options: ApiHandlerOptions
-	private client: OpenAI
+const DEFAULT_HEADERS = {
+	"X-Glama-Metadata": JSON.stringify({ labels: [{ key: "app", value: "vscode.rooveterinaryinc.roo-cline" }] }),
+}
 
+export class GlamaHandler extends RouterProvider implements SingleCompletionHandler {
 	constructor(options: ApiHandlerOptions) {
-		super()
-		this.options = options
-		const baseURL = "https://glama.ai/api/gateway/openai/v1"
-		const apiKey = this.options.glamaApiKey ?? "not-provided"
-		this.client = new OpenAI({ baseURL, apiKey })
-	}
-
-	private supportsTemperature(): boolean {
-		return !this.getModel().id.startsWith("openai/o3-mini")
-	}
-
-	override getModel(): { id: string; info: ModelInfo } {
-		const modelId = this.options.glamaModelId
-		const modelInfo = this.options.glamaModelInfo
-
-		if (modelId && modelInfo) {
-			return { id: modelId, info: modelInfo }
-		}
-
-		return { id: glamaDefaultModelId, info: glamaDefaultModelInfo }
+		super({
+			options,
+			name: "glama",
+			baseURL: "https://glama.ai/api/gateway/openai/v1",
+			apiKey: options.glamaApiKey,
+			modelId: options.glamaModelId ?? glamaDefaultModelId,
+			defaultModelInfo: glamaDefaultModelInfo,
+		})
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		// Convert Anthropic messages to OpenAI format
+		const { id: modelId, info } = await this.fetchModel()
+
 		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
 
-		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-		if (this.getModel().id.startsWith("anthropic/claude-3")) {
+		if (modelId.startsWith("anthropic/claude-3")) {
 			openAiMessages[0] = {
 				role: "system",
-				content: [
-					{
-						type: "text",
-						text: systemPrompt,
-						// @ts-ignore-next-line
-						cache_control: { type: "ephemeral" },
-					},
-				],
+				// @ts-ignore-next-line
+				content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }],
 			}
 
-			// Add cache_control to the last two user messages
-			// (note: this works because we only ever add one user message at a time,
-			// but if we added multiple we'd need to mark the user message before the last assistant message)
 			const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2)
+
 			lastTwoUserMessages.forEach((msg) => {
 				if (typeof msg.content === "string") {
 					msg.content = [{ type: "text", text: msg.content }]
 				}
+
 				if (Array.isArray(msg.content)) {
-					// NOTE: this is fine since env details will always be added at the end.
-					// but if it weren't there, and the user added a image_url type message,
-					// it would pop a text part before it and then move it after to the end.
 					let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
 
 					if (!lastTextPart) {
 						lastTextPart = { type: "text", text: "..." }
 						msg.content.push(lastTextPart)
 					}
+
 					// @ts-ignore-next-line
 					lastTextPart["cache_control"] = { type: "ephemeral" }
 				}
 			})
 		}
 
-		// Required by Anthropic
-		// Other providers default to max tokens allowed.
+		// Required by Anthropic; other providers default to max tokens allowed.
 		let maxTokens: number | undefined
 
-		if (this.getModel().id.startsWith("anthropic/")) {
-			maxTokens = this.getModel().info.maxTokens ?? undefined
+		if (modelId.startsWith("anthropic/")) {
+			maxTokens = info.maxTokens ?? undefined
 		}
 
 		const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = {
-			model: this.getModel().id,
+			model: modelId,
 			max_tokens: maxTokens,
 			messages: openAiMessages,
 			stream: true,
 		}
 
-		if (this.supportsTemperature()) {
+		if (this.supportsTemperature(modelId)) {
 			requestOptions.temperature = this.options.modelTemperature ?? 
GLAMA_DEFAULT_TEMPERATURE } const { data: completion, response } = await this.client.chat.completions .create(requestOptions, { - headers: { - "X-Glama-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "vscode.rooveterinaryinc.roo-cline", - }, - ], - }), - }, + headers: DEFAULT_HEADERS, }) .withResponse() @@ -123,10 +92,7 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle const delta = chunk.choices[0]?.delta if (delta?.content) { - yield { - type: "text", - text: delta.content, - } + yield { type: "text", text: delta.content } } } @@ -140,11 +106,7 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle // before we can fetch information about the token usage and cost. const response = await axios.get( `https://glama.ai/api/gateway/v1/completion-requests/${completionRequestId}`, - { - headers: { - Authorization: `Bearer ${this.options.glamaApiKey}`, - }, - }, + { headers: { Authorization: `Bearer ${this.options.glamaApiKey}` } }, ) const completionRequest = response.data @@ -170,18 +132,20 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle } async completePrompt(prompt: string): Promise { + const { id: modelId, info } = await this.fetchModel() + try { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: this.getModel().id, + model: modelId, messages: [{ role: "user", content: prompt }], } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? GLAMA_DEFAULT_TEMPERATURE } - if (this.getModel().id.startsWith("anthropic/")) { - requestOptions.max_tokens = this.getModel().info.maxTokens + if (modelId.startsWith("anthropic/")) { + requestOptions.max_tokens = info.maxTokens } const response = await this.client.chat.completions.create(requestOptions) @@ -190,45 +154,8 @@ export class GlamaHandler extends BaseProvider implements SingleCompletionHandle if (error instanceof Error) { throw new Error(`Glama completion error: ${error.message}`) } - throw error - } - } -} - -export async function getGlamaModels() { - const models: Record = {} - - try { - const response = await axios.get("https://glama.ai/api/gateway/v1/models") - const rawModels = response.data - - for (const rawModel of rawModels) { - const modelInfo: ModelInfo = { - maxTokens: rawModel.maxTokensOutput, - contextWindow: rawModel.maxTokensInput, - supportsImages: rawModel.capabilities?.includes("input:image"), - supportsComputerUse: rawModel.capabilities?.includes("computer_use"), - supportsPromptCache: rawModel.capabilities?.includes("caching"), - inputPrice: parseApiPrice(rawModel.pricePerToken?.input), - outputPrice: parseApiPrice(rawModel.pricePerToken?.output), - description: undefined, - cacheWritesPrice: parseApiPrice(rawModel.pricePerToken?.cacheWrite), - cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead), - } - switch (rawModel.id) { - case rawModel.id.startsWith("anthropic/"): - modelInfo.maxTokens = 8192 - break - default: - break - } - - models[rawModel.id] = modelInfo + throw error } - } catch (error) { - console.error(`Error fetching Glama models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) } - - return models } diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 4b6a7982ca..71568dfde1 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -74,7 +74,6 @@ export class OpenAiHandler extends BaseProvider 
implements SingleCompletionHandl const enabledR1Format = this.options.openAiR1FormatEnabled ?? false const enabledLegacyFormat = this.options.openAiLegacyFormat ?? false const isAzureAiInference = this._isAzureAiInference(modelUrl) - const urlHost = this._getUrlHost(modelUrl) const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format const ark = modelUrl.includes(".volces.com") diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index a9162495ee..fde6f5313f 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -16,6 +16,7 @@ import { convertToR1Format } from "../transform/r1-format" import { getModelParams, SingleCompletionHandler } from "../index" import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants" import { BaseProvider } from "./base-provider" +import { ModelRecord, getModels } from "./fetchers/cache" const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]" @@ -51,6 +52,7 @@ interface CompletionUsage { export class OpenRouterHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions private client: OpenAI + protected models: ModelRecord = {} constructor(options: ApiHandlerOptions) { super() @@ -66,7 +68,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH systemPrompt: string, messages: Anthropic.Messages.MessageParam[], ): AsyncGenerator { - let { id: modelId, maxTokens, thinking, temperature, topP, reasoningEffort, promptCache } = this.getModel() + let { + id: modelId, + maxTokens, + thinking, + temperature, + topP, + reasoningEffort, + promptCache, + } = await this.fetchModel() // Convert Anthropic messages to OpenAI format. let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ @@ -183,22 +193,27 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } } + private async fetchModel() { + this.models = await getModels("openrouter") + return this.getModel() + } + override getModel() { - const modelId = this.options.openRouterModelId - const modelInfo = this.options.openRouterModelInfo + const id = this.options.openRouterModelId ?? openRouterDefaultModelId + const info = this.models[id] ?? openRouterDefaultModelInfo - let id = modelId ?? openRouterDefaultModelId - const info = modelInfo ?? openRouterDefaultModelInfo - const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning" - const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0 - const topP = isDeepSeekR1 ? 0.95 : undefined + const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || id === "perplexity/sonar-reasoning" return { id, info, // maxTokens, thinking, temperature, reasoningEffort - ...getModelParams({ options: this.options, model: info, defaultTemperature }), - topP, + ...getModelParams({ + options: this.options, + model: info, + defaultTemperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0, + }), + topP: isDeepSeekR1 ? 
0.95 : undefined,
 			promptCache: {
 				supported: PROMPT_CACHING_MODELS.has(id),
 				optional: OPTIONAL_PROMPT_CACHING_MODELS.has(id),
@@ -207,7 +222,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 	}
 
 	async completePrompt(prompt: string) {
-		let { id: modelId, maxTokens, thinking, temperature } = this.getModel()
+		let { id: modelId, maxTokens, thinking, temperature } = await this.fetchModel()
 
 		const completionParams: OpenRouterChatCompletionParams = {
 			model: modelId,
diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts
index 680b6e7179..e13ca2e23d 100644
--- a/src/api/providers/requesty.ts
+++ b/src/api/providers/requesty.ts
@@ -1,10 +1,11 @@
-import axios from "axios"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
-import { ModelInfo, requestyDefaultModelInfo, requestyDefaultModelId } from "../../shared/api"
-import { calculateApiCostOpenAI, parseApiPrice } from "../../utils/cost"
-import { ApiStreamUsageChunk } from "../transform/stream"
+import { ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo } from "../../shared/api"
+import { calculateApiCostOpenAI } from "../../utils/cost"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { OpenAiHandler, OpenAiHandlerOptions } from "./openai"
-import OpenAI from "openai"
+import { ModelRecord, getModels } from "./fetchers/cache"
 
 // Requesty usage includes an extra field for Anthropic use cases.
 // Safely cast the prompt token details section to the appropriate structure.
@@ -17,25 +18,30 @@ interface RequestyUsage extends OpenAI.CompletionUsage {
 }
 
 export class RequestyHandler extends OpenAiHandler {
+	protected models: ModelRecord = {}
+
 	constructor(options: OpenAiHandlerOptions) {
 		if (!options.requestyApiKey) {
 			throw new Error("Requesty API key is required. Please provide it in the settings.")
 		}
+
 		super({
 			...options,
 			openAiApiKey: options.requestyApiKey,
 			openAiModelId: options.requestyModelId ?? requestyDefaultModelId,
 			openAiBaseUrl: "https://router.requesty.ai/v1",
-			openAiCustomModelInfo: options.requestyModelInfo ?? requestyDefaultModelInfo,
 		})
 	}
 
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		this.models = await getModels("requesty")
+		yield* super.createMessage(systemPrompt, messages)
+	}
+
 	override getModel(): { id: string; info: ModelInfo } {
-		const modelId = this.options.requestyModelId ?? requestyDefaultModelId
-		return {
-			id: modelId,
-			info: this.options.requestyModelInfo ?? requestyDefaultModelInfo,
-		}
+		const id = this.options.requestyModelId ?? requestyDefaultModelId
+		const info = this.models[id] ?? requestyDefaultModelInfo
+		return { id, info }
 	}
 
 	protected override processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk {
@@ -47,6 +53,7 @@ export class RequestyHandler extends OpenAiHandler {
 		const totalCost = modelInfo
 			? 
calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) : 0 + return { type: "usage", inputTokens: inputTokens, @@ -56,56 +63,9 @@ export class RequestyHandler extends OpenAiHandler { totalCost: totalCost, } } -} - -export async function getRequestyModels(apiKey?: string) { - const models: Record = {} - - try { - const headers: Record = {} - if (apiKey) { - headers["Authorization"] = `Bearer ${apiKey}` - } - - const url = "https://router.requesty.ai/v1/models" - const response = await axios.get(url, { headers }) - const rawModels = response.data.data - for (const rawModel of rawModels) { - // { - // id: "anthropic/claude-3-5-sonnet-20240620", - // object: "model", - // created: 1740552655, - // owned_by: "system", - // input_price: 0.0000028, - // caching_price: 0.00000375, - // cached_price: 3e-7, - // output_price: 0.000015, - // max_output_tokens: 8192, - // context_window: 200000, - // supports_caching: true, - // description: - // "Anthropic's previous most intelligent model. High level of intelligence and capability. Excells in coding.", - // } - - const modelInfo: ModelInfo = { - maxTokens: rawModel.max_output_tokens, - contextWindow: rawModel.context_window, - supportsPromptCache: rawModel.supports_caching, - supportsImages: rawModel.supports_vision, - supportsComputerUse: rawModel.supports_computer_use, - inputPrice: parseApiPrice(rawModel.input_price), - outputPrice: parseApiPrice(rawModel.output_price), - description: rawModel.description, - cacheWritesPrice: parseApiPrice(rawModel.caching_price), - cacheReadsPrice: parseApiPrice(rawModel.cached_price), - } - - models[rawModel.id] = modelInfo - } - } catch (error) { - console.error(`Error fetching Requesty models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + override async completePrompt(prompt: string): Promise { + this.models = await getModels("requesty") + return super.completePrompt(prompt) } - - return models } diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts new file mode 100644 index 0000000000..d951ead592 --- /dev/null +++ b/src/api/providers/router-provider.ts @@ -0,0 +1,47 @@ +import OpenAI from "openai" + +import { ApiHandlerOptions, ModelInfo } from "../../shared/api" +import { BaseProvider } from "./base-provider" +import { RouterName, ModelRecord, getModels } from "./fetchers/cache" + +type RouterProviderOptions = { + name: RouterName + baseURL: string + apiKey?: string + modelId: string + defaultModelInfo: ModelInfo + options: ApiHandlerOptions +} + +export abstract class RouterProvider extends BaseProvider { + protected readonly options: ApiHandlerOptions + protected readonly name: RouterName + protected models: ModelRecord = {} + protected readonly modelId: string + protected readonly defaultModelInfo: ModelInfo + protected readonly client: OpenAI + + constructor({ options, name, baseURL, apiKey = "not-provided", modelId, defaultModelInfo }: RouterProviderOptions) { + super() + + this.options = options + this.name = name + this.modelId = modelId + this.defaultModelInfo = defaultModelInfo + + this.client = new OpenAI({ baseURL, apiKey }) + } + + protected async fetchModel() { + this.models = await getModels(this.name) + return this.getModel() + } + + override getModel(): { id: string; info: ModelInfo } { + return { id: this.modelId, info: this.models[this.modelId] ?? 
this.defaultModelInfo } + } + + protected supportsTemperature(modelId: string): boolean { + return !modelId.startsWith("openai/o3-mini") + } +} diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index 0413c96f29..9c808c4ad9 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -1,111 +1,89 @@ import { Anthropic } from "@anthropic-ai/sdk" -import axios from "axios" import OpenAI from "openai" -import { ApiHandlerOptions, ModelInfo, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api" +import { ApiHandlerOptions, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" -import { SingleCompletionHandler } from "../" -import { BaseProvider } from "./base-provider" +import { SingleCompletionHandler } from "../index" +import { RouterProvider } from "./router-provider" + +const DEFAULT_HEADERS = { + "X-Unbound-Metadata": JSON.stringify({ labels: [{ key: "app", value: "roo-code" }] }), +} interface UnboundUsage extends OpenAI.CompletionUsage { cache_creation_input_tokens?: number cache_read_input_tokens?: number } -export class UnboundHandler extends BaseProvider implements SingleCompletionHandler { - protected options: ApiHandlerOptions - private client: OpenAI - +export class UnboundHandler extends RouterProvider implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super() - this.options = options - const baseURL = "https://api.getunbound.ai/v1" - const apiKey = this.options.unboundApiKey ?? "not-provided" - this.client = new OpenAI({ baseURL, apiKey }) - } - - private supportsTemperature(): boolean { - return !this.getModel().id.startsWith("openai/o3-mini") + super({ + options, + name: "unbound", + baseURL: "https://api.getunbound.ai/v1", + apiKey: options.unboundApiKey, + modelId: options.unboundModelId ?? unboundDefaultModelId, + defaultModelInfo: unboundDefaultModelInfo, + }) } override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - // Convert Anthropic messages to OpenAI format + const { id: modelId, info } = await this.fetchModel() + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages), ] - // this is specifically for claude models (some models may 'support prompt caching' automatically without this) - if (this.getModel().id.startsWith("anthropic/claude-3")) { + if (modelId.startsWith("anthropic/claude-3")) { openAiMessages[0] = { role: "system", - content: [ - { - type: "text", - text: systemPrompt, - // @ts-ignore-next-line - cache_control: { type: "ephemeral" }, - }, - ], + // @ts-ignore-next-line + content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }], } - // Add cache_control to the last two user messages - // (note: this works because we only ever add one user message at a time, - // but if we added multiple we'd need to mark the user message before the last assistant message) const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) + lastTwoUserMessages.forEach((msg) => { if (typeof msg.content === "string") { msg.content = [{ type: "text", text: msg.content }] } + if (Array.isArray(msg.content)) { - // NOTE: this is fine since env details will always be added at the end. 
- // but if it weren't there, and the user added a image_url type message, - // it would pop a text part before it and then move it after to the end. let lastTextPart = msg.content.filter((part) => part.type === "text").pop() if (!lastTextPart) { lastTextPart = { type: "text", text: "..." } msg.content.push(lastTextPart) } + // @ts-ignore-next-line lastTextPart["cache_control"] = { type: "ephemeral" } } }) } - // Required by Anthropic - // Other providers default to max tokens allowed. + // Required by Anthropic; other providers default to max tokens allowed. let maxTokens: number | undefined - if (this.getModel().id.startsWith("anthropic/")) { - maxTokens = this.getModel().info.maxTokens ?? undefined + if (modelId.startsWith("anthropic/")) { + maxTokens = info.maxTokens ?? undefined } const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: this.getModel().id.split("/")[1], + model: modelId.split("/")[1], max_tokens: maxTokens, messages: openAiMessages, stream: true, } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? 0 } - const { data: completion, response } = await this.client.chat.completions - .create(requestOptions, { - headers: { - "X-Unbound-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "roo-code", - }, - ], - }), - }, - }) + const { data: completion } = await this.client.chat.completions + .create(requestOptions, { headers: DEFAULT_HEADERS }) .withResponse() for await (const chunk of completion) { @@ -113,10 +91,7 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand const usage = chunk.usage as UnboundUsage if (delta?.content) { - yield { - type: "text", - text: delta.content, - } + yield { type: "text", text: delta.content } } if (usage) { @@ -126,10 +101,11 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand outputTokens: usage.completion_tokens || 0, } - // Only add cache tokens if they exist + // Only add cache tokens if they exist. if (usage.cache_creation_input_tokens) { usageData.cacheWriteTokens = usage.cache_creation_input_tokens } + if (usage.cache_read_input_tokens) { usageData.cacheReadTokens = usage.cache_read_input_tokens } @@ -139,94 +115,31 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand } } - override getModel(): { id: string; info: ModelInfo } { - const modelId = this.options.unboundModelId - const modelInfo = this.options.unboundModelInfo - if (modelId && modelInfo) { - return { id: modelId, info: modelInfo } - } - return { - id: unboundDefaultModelId, - info: unboundDefaultModelInfo, - } - } - async completePrompt(prompt: string): Promise { + const { id: modelId, info } = await this.fetchModel() + try { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: this.getModel().id.split("/")[1], + model: modelId.split("/")[1], messages: [{ role: "user", content: prompt }], } - if (this.supportsTemperature()) { + if (this.supportsTemperature(modelId)) { requestOptions.temperature = this.options.modelTemperature ?? 
0 } - if (this.getModel().id.startsWith("anthropic/")) { - requestOptions.max_tokens = this.getModel().info.maxTokens + if (modelId.startsWith("anthropic/")) { + requestOptions.max_tokens = info.maxTokens } - const response = await this.client.chat.completions.create(requestOptions, { - headers: { - "X-Unbound-Metadata": JSON.stringify({ - labels: [ - { - key: "app", - value: "roo-code", - }, - ], - }), - }, - }) + const response = await this.client.chat.completions.create(requestOptions, { headers: DEFAULT_HEADERS }) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { throw new Error(`Unbound completion error: ${error.message}`) } - throw error - } - } -} -export async function getUnboundModels() { - const models: Record = {} - - try { - const response = await axios.get("https://api.getunbound.ai/models") - - if (response.data) { - const rawModels: Record = response.data - - for (const [modelId, model] of Object.entries(rawModels)) { - const modelInfo: ModelInfo = { - maxTokens: model?.maxTokens ? parseInt(model.maxTokens) : undefined, - contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0, - supportsImages: model?.supportsImages ?? false, - supportsPromptCache: model?.supportsPromptCaching ?? false, - supportsComputerUse: model?.supportsComputerUse ?? false, - inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined, - outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined, - cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined, - cacheReadsPrice: model?.cacheReadPrice ? parseFloat(model.cacheReadPrice) : undefined, - } - - switch (true) { - case modelId.startsWith("anthropic/"): - // Set max tokens to 8192 for supported Anthropic models - if (modelInfo.maxTokens !== 4096) { - modelInfo.maxTokens = 8192 - } - break - default: - break - } - - models[modelId] = modelInfo - } + throw error } - } catch (error) { - console.error(`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) } - - return models } diff --git a/src/core/__tests__/Cline.test.ts b/src/core/__tests__/Cline.test.ts index cbad4b1f95..e12078d352 100644 --- a/src/core/__tests__/Cline.test.ts +++ b/src/core/__tests__/Cline.test.ts @@ -11,6 +11,7 @@ import { Cline } from "../Cline" import { ClineProvider } from "../webview/ClineProvider" import { ApiConfiguration, ModelInfo } from "../../shared/api" import { ApiStreamChunk } from "../../api/transform/stream" +import { ContextProxy } from "../config/ContextProxy" // Mock RooIgnoreController jest.mock("../ignore/RooIgnoreController") @@ -225,7 +226,12 @@ describe("Cline", () => { } // Setup mock provider with output channel - mockProvider = new ClineProvider(mockExtensionContext, mockOutputChannel) as jest.Mocked + mockProvider = new ClineProvider( + mockExtensionContext, + mockOutputChannel, + "sidebar", + new ContextProxy(mockExtensionContext), + ) as jest.Mocked // Setup mock API configuration mockApiConfig = { diff --git a/src/core/config/ContextProxy.ts b/src/core/config/ContextProxy.ts index aa40477ad8..dbab107d5b 100644 --- a/src/core/config/ContextProxy.ts +++ b/src/core/config/ContextProxy.ts @@ -256,4 +256,25 @@ export class ContextProxy { await this.initialize() } + + private static _instance: ContextProxy | null = null + + static get instance() { + if (!this._instance) { + throw new Error("ContextProxy not initialized") + } + + return this._instance + } + + static async 
getInstance(context: vscode.ExtensionContext) { + if (this._instance) { + return this._instance + } + + this._instance = new ContextProxy(context) + await this._instance.initialize() + + return this._instance + } } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index bf5901b817..8adf55ea99 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -15,7 +15,6 @@ import { setPanel } from "../../activate/registerCommands" import { ApiConfiguration, ApiProvider, - ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo, openRouterDefaultModelId, @@ -80,7 +79,6 @@ export class ClineProvider extends EventEmitter implements public isViewLaunched = false public settingsImportedAt?: number public readonly latestAnnouncementId = "apr-23-2025-3-14" // Update for v3.14.0 announcement - public readonly contextProxy: ContextProxy public readonly providerSettingsManager: ProviderSettingsManager public readonly customModesManager: CustomModesManager @@ -88,11 +86,11 @@ export class ClineProvider extends EventEmitter implements readonly context: vscode.ExtensionContext, private readonly outputChannel: vscode.OutputChannel, private readonly renderContext: "sidebar" | "editor" = "sidebar", + public readonly contextProxy: ContextProxy, ) { super() this.log("ClineProvider instantiated") - this.contextProxy = new ContextProxy(context) ClineProvider.activeInstances.add(this) // Register this provider with the telemetry service to enable it to add @@ -340,11 +338,6 @@ export class ClineProvider extends EventEmitter implements async resolveWebviewView(webviewView: vscode.WebviewView | vscode.WebviewPanel) { this.log("Resolving webview view") - - if (!this.contextProxy.isInitialized) { - await this.contextProxy.initialize() - } - this.view = webviewView // Set panel reference according to webview type @@ -939,29 +932,6 @@ export class ClineProvider extends EventEmitter implements return getSettingsDirectoryPath(globalStoragePath) } - private async ensureCacheDirectoryExists() { - const { getCacheDirectoryPath } = await import("../../shared/storagePathManager") - const globalStoragePath = this.contextProxy.globalStorageUri.fsPath - return getCacheDirectoryPath(globalStoragePath) - } - - async writeModelsToCache(filename: string, data: T) { - const cacheDir = await this.ensureCacheDirectoryExists() - await fs.writeFile(path.join(cacheDir, filename), JSON.stringify(data)) - } - - async readModelsFromCache(filename: string): Promise | undefined> { - const filePath = path.join(await this.ensureCacheDirectoryExists(), filename) - const fileExists = await fileExistsAtPath(filePath) - - if (fileExists) { - const fileContents = await fs.readFile(filePath, "utf8") - return JSON.parse(fileContents) - } - - return undefined - } - // OpenRouter async handleOpenRouterCallback(code: string) { @@ -990,7 +960,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "openrouter", openRouterApiKey: apiKey, openRouterModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId, - openRouterModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo, } await this.upsertApiConfiguration(currentApiConfigName, newConfiguration) @@ -1021,7 +990,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "glama", glamaApiKey: apiKey, glamaModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId, - glamaModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo, } await 
this.upsertApiConfiguration(currentApiConfigName, newConfiguration) @@ -1037,7 +1005,6 @@ export class ClineProvider extends EventEmitter implements apiProvider: "requesty", requestyApiKey: code, requestyModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId, - requestyModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo, } await this.upsertApiConfiguration(currentApiConfigName, newConfiguration) diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 5d6067bbf5..4c5a28fa4d 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -9,6 +9,7 @@ import { setSoundEnabled } from "../../../utils/sound" import { setTtsEnabled } from "../../../utils/tts" import { defaultModeSlug } from "../../../shared/modes" import { experimentDefault } from "../../../shared/experiments" +import { ContextProxy } from "../../config/ContextProxy" // Mock setup must come before imports jest.mock("../../prompts/sections/custom-instructions") @@ -307,6 +308,7 @@ describe("ClineProvider", () => { // Mock webview mockPostMessage = jest.fn() + mockWebviewView = { webview: { postMessage: mockPostMessage, @@ -325,7 +327,7 @@ describe("ClineProvider", () => { }), } as unknown as vscode.WebviewView - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) // @ts-ignore - Access private property for testing updateGlobalStateSpy = jest.spyOn(provider.contextProxy, "setValue") @@ -357,6 +359,8 @@ describe("ClineProvider", () => { provider = new ClineProvider( { ...mockContext, extensionMode: vscode.ExtensionMode.Development }, mockOutputChannel, + "sidebar", + new ContextProxy(mockContext), ) ;(axios.get as jest.Mock).mockRejectedValueOnce(new Error("Network error")) @@ -810,7 +814,6 @@ describe("ClineProvider", () => { const modeCustomInstructions = "Code mode instructions" const mockApiConfig = { apiProvider: "openrouter", - openRouterModelInfo: { supportsComputerUse: true }, } jest.spyOn(provider, "getState").mockResolvedValue({ @@ -906,7 +909,7 @@ describe("ClineProvider", () => { } as unknown as vscode.ExtensionContext // Create new provider with updated mock context - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) await provider.resolveWebviewView(mockWebviewView) const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] @@ -1069,16 +1072,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter" as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined, - }, }, mcpEnabled: true, enableMcpServerCreation: false, @@ -1102,16 +1095,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter" as const, - openRouterModelInfo: { - supportsComputerUse: true, - supportsPromptCache: false, - maxTokens: 4096, - contextWindow: 8192, - supportsImages: false, - inputPrice: 0.0, - outputPrice: 0.0, - description: undefined, - }, }, mcpEnabled: false, enableMcpServerCreation: false, @@ -1184,7 
+1167,6 @@ describe("ClineProvider", () => { apiConfiguration: { apiProvider: "openrouter", apiModelId: "test-model", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: {}, mode: "code", @@ -1241,7 +1223,6 @@ describe("ClineProvider", () => { apiConfiguration: { apiProvider: "openrouter", apiModelId: "test-model", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: {}, mode: "code", @@ -1282,7 +1263,6 @@ describe("ClineProvider", () => { jest.spyOn(provider, "getState").mockResolvedValue({ apiConfiguration: { apiProvider: "openrouter", - openRouterModelInfo: { supportsComputerUse: true }, }, customModePrompts: { architect: { customInstructions: "Architect mode instructions" }, @@ -1973,7 +1953,7 @@ describe("Project MCP Settings", () => { onDidChangeVisibility: jest.fn(), } as unknown as vscode.WebviewView - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) }) test("handles openProjectMcpSettings message", async () => { @@ -2068,10 +2048,8 @@ describe.skip("ContextProxy integration", () => { } as unknown as vscode.ExtensionContext mockOutputChannel = { appendLine: jest.fn() } as unknown as vscode.OutputChannel - provider = new ClineProvider(mockContext, mockOutputChannel) - - // @ts-ignore - accessing private property for testing - mockContextProxy = provider.contextProxy + mockContextProxy = new ContextProxy(mockContext) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", mockContextProxy) mockGlobalStateUpdate = mockContext.globalState.update as jest.Mock }) @@ -2131,7 +2109,7 @@ describe("getTelemetryProperties", () => { } as unknown as vscode.ExtensionContext mockOutputChannel = { appendLine: jest.fn() } as unknown as vscode.OutputChannel - provider = new ClineProvider(mockContext, mockOutputChannel) + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) // Setup Cline instance with mocked getModel method const { Cline } = require("../../Cline") diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a62c1d9410..a7cfb20c6f 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -8,7 +8,6 @@ import { Language, ApiConfigMeta } from "../../schemas" import { changeLanguage, t } from "../../i18n" import { ApiConfiguration } from "../../shared/api" import { supportPrompt } from "../../shared/support-prompt" -import { GlobalFileNames } from "../../shared/globalFileNames" import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage" import { checkExistKey } from "../../shared/checkExistApiConfig" @@ -25,10 +24,6 @@ import { playTts, setTtsEnabled, setTtsSpeed, stopTts } from "../../utils/tts" import { singleCompletionHandler } from "../../utils/single-completion-handler" import { searchCommits } from "../../utils/git" import { exportSettings, importSettings } from "../config/importExport" -import { getOpenRouterModels } from "../../api/providers/fetchers/openrouter" -import { getGlamaModels } from "../../api/providers/glama" -import { getUnboundModels } from "../../api/providers/unbound" -import { getRequestyModels } from "../../api/providers/requesty" import { getOpenAiModels } from "../../api/providers/openai" import { getOllamaModels } from "../../api/providers/ollama" import { getVsCodeLmModels } from 
"../../api/providers/vscode-lm" @@ -42,6 +37,7 @@ import { SYSTEM_PROMPT } from "../prompts/system" import { buildApiHandler } from "../../api" import { GlobalState } from "../../schemas" import { MultiSearchReplaceDiffStrategy } from "../diff/strategies/multi-search-replace" +import { getModels } from "../../api/providers/fetchers/cache" export const webviewMessageHandler = async (provider: ClineProvider, message: WebviewMessage) => { // Utility functions provided for concise get/update of global state via contextProxy API. @@ -56,104 +52,18 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We await updateGlobalState("customModes", customModes) provider.postStateToWebview() - provider.workspaceTracker?.initializeFilePaths() // don't await + provider.workspaceTracker?.initializeFilePaths() // Don't await. getTheme().then((theme) => provider.postMessageToWebview({ type: "theme", text: JSON.stringify(theme) })) - // If MCP Hub is already initialized, update the webview with current server list + // If MCP Hub is already initialized, update the webview with + // current server list. const mcpHub = provider.getMcpHub() + if (mcpHub) { - provider.postMessageToWebview({ - type: "mcpServers", - mcpServers: mcpHub.getAllServers(), - }) + provider.postMessageToWebview({ type: "mcpServers", mcpServers: mcpHub.getAllServers() }) } - // Post last cached models in case the call to endpoint fails. - provider.readModelsFromCache(GlobalFileNames.openRouterModels).then((openRouterModels) => { - if (openRouterModels) { - provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - } - }) - - // GUI relies on model info to be up-to-date to provide - // the most accurate pricing, so we need to fetch the - // latest details on launch. - // We do this for all users since many users switch - // between api providers and if they were to switch back - // to OpenRouter it would be showing outdated model info - // if we hadn't retrieved the latest at this point - // (see normalizeApiConfiguration > openrouter). - const { apiConfiguration: currentApiConfig } = await provider.getState() - - getOpenRouterModels(currentApiConfig).then(async (openRouterModels) => { - if (Object.keys(openRouterModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.openRouterModels, openRouterModels) - await provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - - // Update model info in state (this needs to be - // done here since we don't want to update state - // while settings is open, and we may refresh - // models there). 
- const { apiConfiguration } = await provider.getState() - - if (apiConfiguration.openRouterModelId) { - await updateGlobalState( - "openRouterModelInfo", - openRouterModels[apiConfiguration.openRouterModelId], - ) - - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.glamaModels).then((glamaModels) => { - if (glamaModels) { - provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - } - }) - - getGlamaModels().then(async (glamaModels) => { - if (Object.keys(glamaModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.glamaModels, glamaModels) - await provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - - const { apiConfiguration } = await provider.getState() - - if (apiConfiguration.glamaModelId) { - await updateGlobalState("glamaModelInfo", glamaModels[apiConfiguration.glamaModelId]) - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.unboundModels).then((unboundModels) => { - if (unboundModels) { - provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - } - }) - - getUnboundModels().then(async (unboundModels) => { - if (Object.keys(unboundModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.unboundModels, unboundModels) - await provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - - const { apiConfiguration } = await provider.getState() - - if (apiConfiguration?.unboundModelId) { - await updateGlobalState("unboundModelInfo", unboundModels[apiConfiguration.unboundModelId]) - await provider.postStateToWebview() - } - } - }) - - provider.readModelsFromCache(GlobalFileNames.requestyModels).then((requestyModels) => { - if (requestyModels) { - provider.postMessageToWebview({ type: "requestyModels", requestyModels }) - } - }) - provider.providerSettingsManager .listConfig() .then(async (listApiConfig) => { @@ -371,51 +281,32 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We case "resetState": await provider.resetState() break - case "refreshOpenRouterModels": { - const { apiConfiguration: configForRefresh } = await provider.getState() - const openRouterModels = await getOpenRouterModels(configForRefresh) - - if (Object.keys(openRouterModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.openRouterModels, openRouterModels) - await provider.postMessageToWebview({ type: "openRouterModels", openRouterModels }) - } - - break - } - case "refreshGlamaModels": - const glamaModels = await getGlamaModels() - - if (Object.keys(glamaModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.glamaModels, glamaModels) - await provider.postMessageToWebview({ type: "glamaModels", glamaModels }) - } - - break - case "refreshUnboundModels": - const unboundModels = await getUnboundModels() - - if (Object.keys(unboundModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.unboundModels, unboundModels) - await provider.postMessageToWebview({ type: "unboundModels", unboundModels }) - } - - break - case "refreshRequestyModels": - const requestyModels = await getRequestyModels(message.values?.apiKey) - - if (Object.keys(requestyModels).length > 0) { - await provider.writeModelsToCache(GlobalFileNames.requestyModels, requestyModels) - await provider.postMessageToWebview({ type: "requestyModels", requestyModels }) - } - + case "requestRouterModels": + const [openRouterModels, requestyModels, glamaModels, unboundModels] = await 
Promise.all([ + getModels("openrouter"), + getModels("requesty"), + getModels("glama"), + getModels("unbound"), + ]) + + provider.postMessageToWebview({ + type: "routerModels", + routerModels: { + openrouter: openRouterModels, + requesty: requestyModels, + glama: glamaModels, + unbound: unboundModels, + }, + }) break - case "refreshOpenAiModels": + case "requestOpenAiModels": if (message?.values?.baseUrl && message?.values?.apiKey) { const openAiModels = await getOpenAiModels( message?.values?.baseUrl, message?.values?.apiKey, message?.values?.hostHeader, ) + provider.postMessageToWebview({ type: "openAiModels", openAiModels }) } @@ -1413,5 +1304,6 @@ const generateSystemPrompt = async (provider: ClineProvider, message: WebviewMes language, rooIgnoreInstructions, ) + return systemPrompt } diff --git a/src/exports/roo-code.d.ts b/src/exports/roo-code.d.ts index 894b776985..3eb31d3b9b 100644 --- a/src/exports/roo-code.d.ts +++ b/src/exports/roo-code.d.ts @@ -28,69 +28,9 @@ type ProviderSettings = { anthropicBaseUrl?: string | undefined anthropicUseAuthToken?: boolean | undefined glamaModelId?: string | undefined - glamaModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined glamaApiKey?: string | undefined openRouterApiKey?: string | undefined openRouterModelId?: string | undefined - openRouterModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined openRouterBaseUrl?: string | undefined openRouterSpecificProvider?: string | undefined openRouterUseMiddleOutTransform?: boolean | undefined @@ -170,68 +110,8 @@ type ProviderSettings = { deepSeekApiKey?: string | undefined unboundApiKey?: string | undefined unboundModelId?: string | undefined - unboundModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number 
- supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined - requestyModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined xaiApiKey?: string | undefined modelMaxTokens?: number | undefined modelMaxThinkingTokens?: number | undefined diff --git a/src/exports/types.ts b/src/exports/types.ts index 4f394c2974..cda03c83d4 100644 --- a/src/exports/types.ts +++ b/src/exports/types.ts @@ -29,69 +29,9 @@ type ProviderSettings = { anthropicBaseUrl?: string | undefined anthropicUseAuthToken?: boolean | undefined glamaModelId?: string | undefined - glamaModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined glamaApiKey?: string | undefined openRouterApiKey?: string | undefined openRouterModelId?: string | undefined - openRouterModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - 
isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined openRouterBaseUrl?: string | undefined openRouterSpecificProvider?: string | undefined openRouterUseMiddleOutTransform?: boolean | undefined @@ -171,68 +111,8 @@ type ProviderSettings = { deepSeekApiKey?: string | undefined unboundApiKey?: string | undefined unboundModelId?: string | undefined - unboundModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined requestyApiKey?: string | undefined requestyModelId?: string | undefined - requestyModelInfo?: - | ({ - maxTokens?: (number | null) | undefined - maxThinkingTokens?: (number | null) | undefined - contextWindow: number - supportsImages?: boolean | undefined - supportsComputerUse?: boolean | undefined - supportsPromptCache: boolean - isPromptCacheOptional?: boolean | undefined - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - description?: string | undefined - reasoningEffort?: ("low" | "medium" | "high") | undefined - thinking?: boolean | undefined - minTokensPerCachePoint?: number | undefined - maxCachePoints?: number | undefined - cachableFields?: string[] | undefined - tiers?: - | { - contextWindow: number - inputPrice?: number | undefined - outputPrice?: number | undefined - cacheWritesPrice?: number | undefined - cacheReadsPrice?: number | undefined - }[] - | undefined - } | null) - | undefined xaiApiKey?: string | undefined modelMaxTokens?: number | undefined modelMaxThinkingTokens?: number | undefined diff --git a/src/extension.ts b/src/extension.ts index aa834c560e..d895bb0e1b 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -15,6 +15,7 @@ try { import "./utils/path" // Necessary to have access to String.prototype.toPosix. 
import { initializeI18n } from "./i18n" +import { ContextProxy } from "./core/config/ContextProxy" import { ClineProvider } from "./core/webview/ClineProvider" import { CodeActionProvider } from "./core/CodeActionProvider" import { DIFF_VIEW_URI_SCHEME } from "./integrations/editor/DiffViewProvider" @@ -66,7 +67,8 @@ export async function activate(context: vscode.ExtensionContext) { context.globalState.update("allowedCommands", defaultCommands) } - const provider = new ClineProvider(context, outputChannel, "sidebar") + const contextProxy = await ContextProxy.getInstance(context) + const provider = new ClineProvider(context, outputChannel, "sidebar", contextProxy) telemetryService.setProvider(provider) context.subscriptions.push( diff --git a/src/schemas/index.ts b/src/schemas/index.ts index 2dc4d3f6a1..a6b2cabcda 100644 --- a/src/schemas/index.ts +++ b/src/schemas/index.ts @@ -321,12 +321,10 @@ export const providerSettingsSchema = z.object({ anthropicUseAuthToken: z.boolean().optional(), // Glama glamaModelId: z.string().optional(), - glamaModelInfo: modelInfoSchema.nullish(), glamaApiKey: z.string().optional(), // OpenRouter openRouterApiKey: z.string().optional(), openRouterModelId: z.string().optional(), - openRouterModelInfo: modelInfoSchema.nullish(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), openRouterUseMiddleOutTransform: z.boolean().optional(), @@ -388,11 +386,9 @@ export const providerSettingsSchema = z.object({ // Unbound unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), - unboundModelInfo: modelInfoSchema.nullish(), // Requesty requestyApiKey: z.string().optional(), requestyModelId: z.string().optional(), - requestyModelInfo: modelInfoSchema.nullish(), // X.AI (Grok) xaiApiKey: z.string().optional(), // Claude 3.7 Sonnet Thinking @@ -423,12 +419,10 @@ const providerSettingsRecord: ProviderSettingsRecord = { anthropicUseAuthToken: undefined, // Glama glamaModelId: undefined, - glamaModelInfo: undefined, glamaApiKey: undefined, // OpenRouter openRouterApiKey: undefined, openRouterModelId: undefined, - openRouterModelInfo: undefined, openRouterBaseUrl: undefined, openRouterSpecificProvider: undefined, openRouterUseMiddleOutTransform: undefined, @@ -482,11 +476,9 @@ const providerSettingsRecord: ProviderSettingsRecord = { // Unbound unboundApiKey: undefined, unboundModelId: undefined, - unboundModelInfo: undefined, // Requesty requestyApiKey: undefined, requestyModelId: undefined, - requestyModelInfo: undefined, // Claude 3.7 Sonnet Thinking modelMaxTokens: undefined, modelMaxThinkingTokens: undefined, diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index b942188345..11366cfede 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -33,24 +33,20 @@ export interface ExtensionMessage { | "action" | "state" | "selectedImages" - | "ollamaModels" - | "lmStudioModels" | "theme" | "workspaceUpdated" | "invoke" | "partialMessage" - | "openRouterModels" - | "glamaModels" - | "unboundModels" - | "requestyModels" - | "openAiModels" | "mcpServers" | "enhancedPrompt" | "commitSearchResults" | "listApiConfig" + | "routerModels" + | "openAiModels" + | "ollamaModels" + | "lmStudioModels" | "vsCodeLmModels" | "vsCodeLmApiAvailable" - | "requestVsCodeLmModels" | "updatePrompt" | "systemPrompt" | "autoApprovalEnabled" @@ -81,9 +77,6 @@ export interface ExtensionMessage { invoke?: "newChat" | "sendMessage" | "primaryButtonClick" | "secondaryButtonClick" | 
"setChatBoxMessage" state?: ExtensionState images?: string[] - ollamaModels?: string[] - lmStudioModels?: string[] - vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] filePaths?: string[] openedTabs?: Array<{ label: string @@ -91,11 +84,11 @@ export interface ExtensionMessage { path?: string }> partialMessage?: ClineMessage - openRouterModels?: Record - glamaModels?: Record - unboundModels?: Record - requestyModels?: Record + routerModels?: Record<"openrouter" | "requesty" | "glama" | "unbound", Record> openAiModels?: string[] + ollamaModels?: string[] + lmStudioModels?: string[] + vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] mcpServers?: McpServer[] commits?: GitCommit[] listApiConfig?: ApiConfigMeta[] @@ -106,11 +99,7 @@ export interface ExtensionMessage { values?: Record requestId?: string promptText?: string - results?: Array<{ - path: string - type: "file" | "folder" - label?: string - }> + results?: { path: string; type: "file" | "folder"; label?: string }[] error?: string } diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 6b5c111f7a..7b8c4017cd 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -40,17 +40,15 @@ export interface WebviewMessage { | "importSettings" | "exportSettings" | "resetState" + | "requestRouterModels" + | "requestOpenAiModels" | "requestOllamaModels" | "requestLmStudioModels" + | "requestVsCodeLmModels" | "openImage" | "openFile" | "openMention" | "cancelTask" - | "refreshOpenRouterModels" - | "refreshGlamaModels" - | "refreshUnboundModels" - | "refreshRequestyModels" - | "refreshOpenAiModels" | "alwaysAllowBrowser" | "alwaysAllowMcp" | "alwaysAllowModeSwitch" @@ -94,7 +92,6 @@ export interface WebviewMessage { | "alwaysApproveResubmit" | "requestDelaySeconds" | "setApiConfigPassword" - | "requestVsCodeLmModels" | "mode" | "updatePrompt" | "updateSupportPrompt" diff --git a/src/shared/globalFileNames.ts b/src/shared/globalFileNames.ts index 68990dfe95..b82ea9b00e 100644 --- a/src/shared/globalFileNames.ts +++ b/src/shared/globalFileNames.ts @@ -1,11 +1,7 @@ export const GlobalFileNames = { apiConversationHistory: "api_conversation_history.json", uiMessages: "ui_messages.json", - glamaModels: "glama_models.json", - openRouterModels: "openrouter_models.json", - requestyModels: "requesty_models.json", mcpSettings: "mcp_settings.json", - unboundModels: "unbound_models.json", customModes: "custom_modes.json", taskMetadata: "task_metadata.json", } diff --git a/webview-ui/src/components/chat/ChatView.tsx b/webview-ui/src/components/chat/ChatView.tsx index 419537b361..72d587d89d 100644 --- a/webview-ui/src/components/chat/ChatView.tsx +++ b/webview-ui/src/components/chat/ChatView.tsx @@ -24,7 +24,7 @@ import { getAllModes } from "@roo/shared/modes" import { useExtensionState } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" -import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration" +import { useSelectedModel } from "@src/utils/normalizeApiConfiguration" import { validateCommand } from "@src/utils/command-validation" import { useAppTranslation } from "@src/i18n/TranslationContext" @@ -490,16 +490,14 @@ const ChatViewComponent: React.ForwardRefRenderFunction { - return normalizeApiConfiguration(apiConfiguration) - }, [apiConfiguration]) + const { info: model } = useSelectedModel(apiConfiguration) const selectImages = useCallback(() => { vscode.postMessage({ type: "selectImages" 
}) }, []) const shouldDisableImages = - !selectedModelInfo.supportsImages || textAreaDisabled || selectedImages.length >= MAX_IMAGES_PER_MESSAGE + !model?.supportsImages || textAreaDisabled || selectedImages.length >= MAX_IMAGES_PER_MESSAGE const handleMessage = useCallback( (e: MessageEvent) => { @@ -1216,7 +1214,7 @@ const ChatViewComponent: React.ForwardRefRenderFunction { const { t } = useTranslation() const { apiConfiguration, currentTaskItem } = useExtensionState() - const { selectedModelInfo } = useMemo(() => normalizeApiConfiguration(apiConfiguration), [apiConfiguration]) + const { info: model } = useSelectedModel(apiConfiguration) const [isTaskExpanded, setIsTaskExpanded] = useState(false) const textContainerRef = useRef(null) const textRef = useRef(null) - const contextWindow = selectedModelInfo?.contextWindow || 1 + const contextWindow = model?.contextWindow || 1 const { width: windowWidth } = useWindowSize() @@ -96,7 +96,7 @@ const TaskHeader = ({ {!!totalCost && ${totalCost.toFixed(2)}} @@ -132,7 +132,7 @@ const TaskHeader = ({ )} diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 052bc6e5fe..d4eab02b36 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -13,15 +13,11 @@ import { ModelInfo, azureOpenAiDefaultApiVersion, glamaDefaultModelId, - glamaDefaultModelInfo, mistralDefaultModelId, openAiModelInfoSaneDefaults, openRouterDefaultModelId, - openRouterDefaultModelInfo, unboundDefaultModelId, - unboundDefaultModelInfo, requestyDefaultModelId, - requestyDefaultModelInfo, ApiProvider, } from "@roo/shared/api" import { ExtensionMessage } from "@roo/shared/ExtensionMessage" @@ -29,7 +25,7 @@ import { AWS_REGIONS } from "@roo/shared/aws_regions" import { vscode } from "@src/utils/vscode" import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@src/utils/validate" -import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration" +import { useRouterModels, useSelectedModel } from "@src/utils/normalizeApiConfiguration" import { useOpenRouterModelProviders, OPENROUTER_DEFAULT_PROVIDER_NAME, @@ -75,22 +71,6 @@ const ApiOptions = ({ const [lmStudioModels, setLmStudioModels] = useState([]) const [vsCodeLmModels, setVsCodeLmModels] = useState([]) - const [openRouterModels, setOpenRouterModels] = useState>({ - [openRouterDefaultModelId]: openRouterDefaultModelInfo, - }) - - const [glamaModels, setGlamaModels] = useState>({ - [glamaDefaultModelId]: glamaDefaultModelInfo, - }) - - const [unboundModels, setUnboundModels] = useState>({ - [unboundDefaultModelId]: unboundDefaultModelInfo, - }) - - const [requestyModels, setRequestyModels] = useState>({ - [requestyDefaultModelId]: requestyDefaultModelInfo, - }) - const [openAiModels, setOpenAiModels] = useState | null>(null) const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicBaseUrl) @@ -117,10 +97,13 @@ const ApiOptions = ({ [setApiConfigurationField], ) - const { selectedProvider, selectedModelId, selectedModelInfo } = useMemo( - () => normalizeApiConfiguration(apiConfiguration), - [apiConfiguration], - ) + const { + provider: selectedProvider, + id: selectedModelId, + info: selectedModelInfo, + } = useSelectedModel(apiConfiguration) + + const { data: routerModels } = useRouterModels() // Update apiConfiguration.aiModelId whenever selectedModelId changes. 
useEffect(() => { @@ -133,20 +116,9 @@ const ApiOptions = ({ // stops typing. useDebounce( () => { - if (selectedProvider === "openrouter") { - vscode.postMessage({ type: "refreshOpenRouterModels" }) - } else if (selectedProvider === "glama") { - vscode.postMessage({ type: "refreshGlamaModels" }) - } else if (selectedProvider === "unbound") { - vscode.postMessage({ type: "refreshUnboundModels" }) - } else if (selectedProvider === "requesty") { + if (selectedProvider === "openai") { vscode.postMessage({ - type: "refreshRequestyModels", - values: { apiKey: apiConfiguration?.requestyApiKey }, - }) - } else if (selectedProvider === "openai") { - vscode.postMessage({ - type: "refreshOpenAiModels", + type: "requestOpenAiModels", values: { baseUrl: apiConfiguration?.openAiBaseUrl, apiKey: apiConfiguration?.openAiApiKey, @@ -174,43 +146,23 @@ const ApiOptions = ({ useEffect(() => { const apiValidationResult = - validateApiConfiguration(apiConfiguration) || - validateModelId(apiConfiguration, glamaModels, openRouterModels, unboundModels, requestyModels) - + validateApiConfiguration(apiConfiguration) || validateModelId(apiConfiguration, routerModels) setErrorMessage(apiValidationResult) - }, [apiConfiguration, glamaModels, openRouterModels, setErrorMessage, unboundModels, requestyModels]) + }, [apiConfiguration, routerModels, setErrorMessage]) const { data: openRouterModelProviders } = useOpenRouterModelProviders(apiConfiguration?.openRouterModelId, { enabled: selectedProvider === "openrouter" && !!apiConfiguration?.openRouterModelId && - apiConfiguration.openRouterModelId in openRouterModels, + routerModels?.openrouter && + Object.keys(routerModels.openrouter).length > 1 && + apiConfiguration.openRouterModelId in routerModels.openrouter, }) const onMessage = useCallback((event: MessageEvent) => { const message: ExtensionMessage = event.data switch (message.type) { - case "openRouterModels": { - const updatedModels = message.openRouterModels ?? {} - setOpenRouterModels({ [openRouterDefaultModelId]: openRouterDefaultModelInfo, ...updatedModels }) - break - } - case "glamaModels": { - const updatedModels = message.glamaModels ?? {} - setGlamaModels({ [glamaDefaultModelId]: glamaDefaultModelInfo, ...updatedModels }) - break - } - case "unboundModels": { - const updatedModels = message.unboundModels ?? {} - setUnboundModels({ [unboundDefaultModelId]: unboundDefaultModelInfo, ...updatedModels }) - break - } - case "requestyModels": { - const updatedModels = message.requestyModels ?? {} - setRequestyModels({ [requestyDefaultModelId]: requestyDefaultModelInfo, ...updatedModels }) - break - } case "openAiModels": { const updatedModels = message.openAiModels ?? [] setOpenAiModels(Object.fromEntries(updatedModels.map((item) => [item, openAiModelInfoSaneDefaults]))) @@ -825,10 +777,8 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} defaultModelId="gpt-4o" - defaultModelInfo={openAiModelInfoSaneDefaults} models={openAiModels} modelIdKey="openAiModelId" - modelInfoKey="openAiCustomModelInfo" serviceName="OpenAI" serviceUrl="https://platform.openai.com" /> @@ -1535,10 +1485,8 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} defaultModelId={openRouterDefaultModelId} - defaultModelInfo={openRouterDefaultModelInfo} - models={openRouterModels} + models={routerModels?.openrouter ?? 
{}} modelIdKey="openRouterModelId" - modelInfoKey="openRouterModelInfo" serviceName="OpenRouter" serviceUrl="https://openrouter.ai/models" /> @@ -1558,16 +1506,7 @@ const ApiOptions = ({ , + VSCodeRadio: ({ value, checked }: any) => , VSCodeRadioGroup: ({ children }: any) =>
{children}
, VSCodeButton: ({ children }: any) =>
{children}
, })) @@ -54,6 +56,11 @@ jest.mock("@/components/ui", () => ({ {children} ), + Slider: ({ value, onChange }: any) => ( +
+ onChange(parseFloat(e.target.value))} /> +
+ ), })) jest.mock("../TemperatureControl", () => ({ @@ -86,16 +93,6 @@ jest.mock("../RateLimitSecondsControl", () => ({ ), })) -// Mock ThinkingBudget component -jest.mock("../ThinkingBudget", () => ({ - ThinkingBudget: ({ apiConfiguration, setApiConfigurationField, modelInfo, provider }: any) => - modelInfo?.thinking ? ( -
- -
- ) : null, -})) - // Mock DiffSettingsControl for tests jest.mock("../DiffSettingsControl", () => ({ DiffSettingsControl: ({ diffEnabled, fuzzyMatchThreshold, onChange }: any) => ( @@ -123,7 +120,23 @@ jest.mock("../DiffSettingsControl", () => ({ ), })) -const renderApiOptions = (props = {}) => { +jest.mock("@src/components/ui/hooks/useSelectedModel", () => ({ + useSelectedModel: jest.fn((apiConfiguration: ApiConfiguration) => { + if (apiConfiguration.apiModelId?.includes("thinking")) { + return { + provider: apiConfiguration.apiProvider, + info: { thinking: true, contextWindow: 4000, maxTokens: 128000 }, + } + } else { + return { + provider: apiConfiguration.apiProvider, + info: { contextWindow: 4000 }, + } + } + }), +})) + +const renderApiOptions = (props: Partial = {}) => { const queryClient = new QueryClient() render( @@ -192,7 +205,6 @@ describe("ApiOptions", () => { apiConfiguration: { apiProvider: "anthropic", apiModelId: "claude-3-opus-20240229", - modelInfo: { thinking: false }, // Non-thinking model }, }) diff --git a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx index 9225c9f76d..710ca3e5ef 100644 --- a/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx +++ b/webview-ui/src/components/settings/__tests__/ModelPicker.test.tsx @@ -2,6 +2,9 @@ import { screen, fireEvent, render } from "@testing-library/react" import { act } from "react" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + +import { ModelInfo } from "@roo/schemas" import { ModelPicker } from "../ModelPicker" @@ -21,7 +24,8 @@ Element.prototype.scrollIntoView = jest.fn() describe("ModelPicker", () => { const mockSetApiConfigurationField = jest.fn() - const modelInfo = { + + const modelInfo: ModelInfo = { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, @@ -32,14 +36,15 @@ describe("ModelPicker", () => { cacheWritesPrice: 3.75, cacheReadsPrice: 0.3, } + const mockModels = { model1: { name: "Model 1", description: "Test model 1", ...modelInfo }, model2: { name: "Model 2", description: "Test model 2", ...modelInfo }, } + const defaultProps = { apiConfiguration: {}, defaultModelId: "model1", - defaultModelInfo: modelInfo, modelIdKey: "glamaModelId" as const, serviceName: "Test Service", serviceUrl: "https://test.service", @@ -48,14 +53,22 @@ describe("ModelPicker", () => { setApiConfigurationField: mockSetApiConfigurationField, } + const queryClient = new QueryClient() + + const renderModelPicker = () => { + return render( + + + , + ) + } + beforeEach(() => { jest.clearAllMocks() }) it("calls setApiConfigurationField when a model is selected", async () => { - await act(async () => { - render() - }) + await act(async () => renderModelPicker()) await act(async () => { // Open the popover by clicking the button. @@ -86,9 +99,7 @@ describe("ModelPicker", () => { }) it("allows setting a custom model ID that's not in the predefined list", async () => { - await act(async () => { - render() - }) + await act(async () => renderModelPicker()) await act(async () => { // Open the popover by clicking the button. 
@@ -123,7 +134,5 @@ describe("ModelPicker", () => { // Verify the API config was updated with the custom model ID expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, customModelId) - // The model info should be set to the default since this is a custom model - expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.defaultModelInfo) }) }) diff --git a/webview-ui/src/components/ui/hooks/useRouterModels.ts b/webview-ui/src/components/ui/hooks/useRouterModels.ts new file mode 100644 index 0000000000..8140b8e533 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/useRouterModels.ts @@ -0,0 +1,40 @@ +import { ModelInfo } from "@roo/shared/api" + +import { vscode } from "@src/utils/vscode" +import { ExtensionMessage } from "@roo/shared/ExtensionMessage" +import { useQuery } from "@tanstack/react-query" + +export type RouterModels = Record<"openrouter" | "requesty" | "glama" | "unbound", Record> + +const getRouterModels = async () => + new Promise((resolve, reject) => { + const cleanup = () => { + window.removeEventListener("message", handler) + } + + const timeout = setTimeout(() => { + cleanup() + reject(new Error("Router models request timed out")) + }, 10000) + + const handler = (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "routerModels") { + clearTimeout(timeout) + cleanup() + + if (message.routerModels) { + console.log("message.routerModels", message.routerModels) + resolve(message.routerModels) + } else { + reject(new Error("No router models in response")) + } + } + } + + window.addEventListener("message", handler) + vscode.postMessage({ type: "requestRouterModels" }) + }) + +export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels }) diff --git a/webview-ui/src/utils/normalizeApiConfiguration.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts similarity index 78% rename from webview-ui/src/utils/normalizeApiConfiguration.ts rename to webview-ui/src/components/ui/hooks/useSelectedModel.ts index bf7f4f21df..89f9e3a5ca 100644 --- a/webview-ui/src/utils/normalizeApiConfiguration.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -26,44 +26,7 @@ import { unboundDefaultModelId, } from "@roo/shared/api" -import { vscode } from "@src/utils/vscode" -import { ExtensionMessage } from "@roo/shared/ExtensionMessage" -import { useQuery } from "@tanstack/react-query" - -type RouterModels = Record<"openrouter" | "requesty" | "glama" | "unbound", Record> - -export const getRouterModels = async () => - new Promise((resolve, reject) => { - const cleanup = () => { - window.removeEventListener("message", handler) - } - - const timeout = setTimeout(() => { - cleanup() - reject(new Error("Router models request timed out")) - }, 10000) - - const handler = (event: MessageEvent) => { - const message: ExtensionMessage = event.data - - if (message.type === "routerModels") { - clearTimeout(timeout) - cleanup() - - if (message.routerModels) { - console.log("message.routerModels", message.routerModels) - resolve(message.routerModels) - } else { - reject(new Error("No router models in response")) - } - } - } - - window.addEventListener("message", handler) - vscode.postMessage({ type: "requestRouterModels" }) - }) - -export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels }) +import { type RouterModels, useRouterModels } from "./useRouterModels" export const useSelectedModel = (apiConfiguration?: 
ApiConfiguration) => { const { data: routerModels, isLoading, isError } = useRouterModels() From 31e43966d239834d9f382d5cbdab08a0e183b227 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 12:57:40 -0700 Subject: [PATCH 05/15] Fix more tests --- webview-ui/src/components/chat/ChatView.tsx | 2 +- .../chat/__tests__/ChatView.test.tsx | 210 ++++-------------- 2 files changed, 44 insertions(+), 168 deletions(-) diff --git a/webview-ui/src/components/chat/ChatView.tsx b/webview-ui/src/components/chat/ChatView.tsx index e9cd054c57..25dffe9dcb 100644 --- a/webview-ui/src/components/chat/ChatView.tsx +++ b/webview-ui/src/components/chat/ChatView.tsx @@ -40,7 +40,7 @@ import TaskHeader from "./TaskHeader" import AutoApproveMenu from "./AutoApproveMenu" import SystemPromptWarning from "./SystemPromptWarning" -interface ChatViewProps { +export interface ChatViewProps { isHidden: boolean showAnnouncement: boolean hideAnnouncement: () => void diff --git a/webview-ui/src/components/chat/__tests__/ChatView.test.tsx b/webview-ui/src/components/chat/__tests__/ChatView.test.tsx index adc19ca0bb..13d9433f8d 100644 --- a/webview-ui/src/components/chat/__tests__/ChatView.test.tsx +++ b/webview-ui/src/components/chat/__tests__/ChatView.test.tsx @@ -1,9 +1,14 @@ +// npx jest src/components/chat/__tests__/ChatView.test.tsx + import React from "react" import { render, waitFor, act } from "@testing-library/react" -import ChatView from "../ChatView" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" +import ChatView, { ChatViewProps } from "../ChatView" + // Define minimal types needed for testing interface ClineMessage { type: "say" | "ask" @@ -85,13 +90,6 @@ jest.mock("../ChatTextArea", () => { } }) -jest.mock("../TaskHeader", () => ({ - __esModule: true, - default: function MockTaskHeader({ task }: { task: ClineMessage }) { - return
{JSON.stringify(task)}
- }, -})) - // Mock VSCode components jest.mock("@vscode/webview-ui-toolkit/react", () => ({ VSCodeButton: function MockVSCodeButton({ @@ -151,22 +149,30 @@ const mockPostMessage = (state: Partial) => { ) } +const defaultProps: ChatViewProps = { + isHidden: false, + showAnnouncement: false, + hideAnnouncement: () => {}, + showHistoryView: () => {}, +} + +const queryClient = new QueryClient() + +const renderChatView = (props: Partial = {}) => { + return render( + + + + + , + ) +} + describe("ChatView - Auto Approval Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not auto-approve any actions when autoApprovalEnabled is false", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -240,16 +246,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves browser actions when alwaysAllowBrowser is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -296,16 +293,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves read-only tools when alwaysAllowReadOnly is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -353,16 +341,7 @@ describe("ChatView - Auto Approval Tests", () => { describe("Write Tool Auto-Approval Tests", () => { it("auto-approves write tools when alwaysAllowWrite is enabled and message is a tool request", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -411,16 +390,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve write operations when alwaysAllowWrite is enabled but message is not a tool request", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -466,16 +436,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves allowed commands when alwaysAllowExecute is enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -524,16 +485,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve disallowed commands even when alwaysAllowExecute is enabled", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -581,16 +533,7 @@ describe("ChatView - Auto Approval Tests", () => { describe("Command Chaining Tests", () => { it("auto-approves chained commands when all parts are allowed", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test various allowed command chaining scenarios const allowedChainedCommands = [ @@ -656,16 +599,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve chained commands when any part is disallowed", () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test various command chaining scenarios with disallowed parts const disallowedChainedCommands = [ @@ -728,16 +662,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("handles complex PowerShell 
command chains correctly", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // Test PowerShell specific command chains const powershellCommands = { @@ -849,21 +774,10 @@ describe("ChatView - Auto Approval Tests", () => { }) describe("ChatView - Sound Playing Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not play sound for auto-approved browser actions", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -915,16 +829,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays notification sound for non-auto-approved browser actions", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -978,16 +883,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays celebration sound for completion results", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -1037,16 +933,7 @@ describe("ChatView - Sound Playing Tests", () => { }) it("plays progress_loop sound for api failures", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ @@ -1097,9 +984,7 @@ describe("ChatView - Sound Playing Tests", () => { }) describe("ChatView - Focus Grabbing Tests", () => { - beforeEach(() => { - jest.clearAllMocks() - }) + beforeEach(() => jest.clearAllMocks()) it("does not grab focus when follow-up question presented", async () => { const sleep = async (timeout: number) => { @@ -1108,16 +993,7 @@ describe("ChatView - Focus Grabbing Tests", () => { }) } - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task and streaming mockPostMessage({ From dcbd2fe3c42e9891ca82da1ff7b33b6d3dd2d388 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:00:16 -0700 Subject: [PATCH 06/15] Fix more tests --- .../__tests__/ChatView.auto-approve.test.tsx | 148 +++++------------- 1 file changed, 36 insertions(+), 112 deletions(-) diff --git a/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx b/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx index f1afd66b26..704c705d39 100644 --- a/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx +++ b/webview-ui/src/components/chat/__tests__/ChatView.auto-approve.test.tsx @@ -1,9 +1,13 @@ -import React from "react" +// npx jest src/components/chat/__tests__/ChatView.auto-approve.test.tsx + import { render, waitFor } from "@testing-library/react" -import ChatView from "../ChatView" +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext" import { vscode } from "@src/utils/vscode" +import ChatView, { ChatViewProps } from "../ChatView" + // Mock vscode API jest.mock("@src/utils/vscode", () => ({ vscode: { @@ -85,22 +89,32 @@ const mockPostMessage = (state: any) => { ) } +const queryClient = new QueryClient() + +const defaultProps: ChatViewProps = { + isHidden: false, + showAnnouncement: false, + hideAnnouncement: () => {}, + showHistoryView: () => {}, +} + +const renderChatView = 
(props: Partial = {}) => { + return render( + + + + + , + ) +} + describe("ChatView - Auto Approval Tests", () => { beforeEach(() => { jest.clearAllMocks() }) it("auto-approves read operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -147,16 +161,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves outside workspace read operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -235,16 +240,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve outside workspace read operations without permission", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -297,16 +293,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve when autoApprovalEnabled is false", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -351,16 +338,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves write operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -409,16 +387,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves outside workspace write operations when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -474,16 +443,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve outside workspace write operations without permission", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -539,16 +499,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves browser actions when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -595,16 +546,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("auto-approves mode switch when enabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -651,16 +593,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve mode switch when disabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ @@ -705,16 +638,7 @@ describe("ChatView - Auto Approval Tests", () => { }) it("does not auto-approve mode switch when auto-approval is disabled", async () => { - render( - - {}} - showHistoryView={() => {}} - /> - , - ) + renderChatView() // First hydrate state with initial task mockPostMessage({ From 58ecbdc5d83bbc88f85670e8f8ee9b5fda6856e8 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:02:51 -0700 Subject: [PATCH 07/15] Make knip happy --- src/api/providers/fetchers/index.ts | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 src/api/providers/fetchers/index.ts 
diff --git a/src/api/providers/fetchers/index.ts b/src/api/providers/fetchers/index.ts deleted file mode 100644 index 24644c6542..0000000000 --- a/src/api/providers/fetchers/index.ts +++ /dev/null @@ -1,4 +0,0 @@ -export { getOpenRouterModels } from "./openrouter" -export { getRequestyModels } from "./requesty" -export { getGlamaModels } from "./glama" -export { getUnboundModels } from "./unbound" From af3addeb5cf102ccc82e3773ff21633d55ddec9a Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:14:57 -0700 Subject: [PATCH 08/15] Fix tests --- src/api/providers/__tests__/unbound.test.ts | 72 +++++++++------------ src/api/providers/router-provider.ts | 2 +- src/shared/api.ts | 2 +- 3 files changed, 32 insertions(+), 44 deletions(-) diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts index 8c93a4ff01..e174eb7e99 100644 --- a/src/api/providers/__tests__/unbound.test.ts +++ b/src/api/providers/__tests__/unbound.test.ts @@ -21,12 +21,7 @@ jest.mock("openai", () => { [Symbol.asyncIterator]: async function* () { // First chunk with content yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], + choices: [{ delta: { content: "Test response" }, index: 0 }], } // Second chunk with usage data yield { @@ -52,15 +47,14 @@ jest.mock("openai", () => { } const result = mockCreate(...args) + if (args[0].stream) { mockWithResponse.mockReturnValue( - Promise.resolve({ - data: stream, - response: { headers: new Map() }, - }), + Promise.resolve({ data: stream, response: { headers: new Map() } }), ) result.withResponse = mockWithResponse } + return result }, }, @@ -75,10 +69,10 @@ describe("UnboundHandler", () => { beforeEach(() => { mockOptions = { - apiModelId: "anthropic/claude-3-5-sonnet-20241022", unboundApiKey: "test-api-key", unboundModelId: "anthropic/claude-3-5-sonnet-20241022", } + handler = new UnboundHandler(mockOptions) mockCreate.mockClear() mockWithResponse.mockClear() @@ -97,9 +91,9 @@ describe("UnboundHandler", () => { }) describe("constructor", () => { - it("should initialize with provided options", () => { + it("should initialize with provided options", async () => { expect(handler).toBeInstanceOf(UnboundHandler) - expect(handler.getModel().id).toBe(mockOptions.apiModelId) + expect((await handler.fetchModel()).id).toBe(mockOptions.unboundModelId) }) }) @@ -115,6 +109,7 @@ describe("UnboundHandler", () => { it("should handle streaming responses with text and usage data", async () => { const stream = handler.createMessage(systemPrompt, messages) const chunks: Array<{ type: string } & Record> = [] + for await (const chunk of stream) { chunks.push(chunk) } @@ -122,17 +117,10 @@ describe("UnboundHandler", () => { expect(chunks.length).toBe(3) // Verify text chunk - expect(chunks[0]).toEqual({ - type: "text", - text: "Test response", - }) + expect(chunks[0]).toEqual({ type: "text", text: "Test response" }) // Verify regular usage data - expect(chunks[1]).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 5, - }) + expect(chunks[1]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 5 }) // Verify usage data with cache information expect(chunks[2]).toEqual({ @@ -149,6 +137,7 @@ describe("UnboundHandler", () => { messages: expect.any(Array), stream: true, }), + expect.objectContaining({ headers: { "X-Unbound-Metadata": expect.stringContaining("roo-code"), @@ -169,6 +158,7 @@ describe("UnboundHandler", () => { for await (const chunk of stream) { chunks.push(chunk) } + fail("Expected error to be 
thrown") } catch (error) { expect(error).toBeInstanceOf(Error) @@ -181,6 +171,7 @@ describe("UnboundHandler", () => { it("should complete prompt successfully", async () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "claude-3-5-sonnet-20241022", @@ -202,9 +193,7 @@ describe("UnboundHandler", () => { }) it("should handle empty response", async () => { - mockCreate.mockResolvedValueOnce({ - choices: [{ message: { content: "" } }], - }) + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "" } }] }) const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) @@ -212,14 +201,14 @@ describe("UnboundHandler", () => { it("should not set max_tokens for non-Anthropic models", async () => { mockCreate.mockClear() - const nonAnthropicOptions = { + const nonAnthropicHandler = new UnboundHandler({ apiModelId: "openai/gpt-4o", unboundApiKey: "test-key", unboundModelId: "openai/gpt-4o", - } - const nonAnthropicHandler = new UnboundHandler(nonAnthropicOptions) + }) await nonAnthropicHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "gpt-4o", @@ -232,20 +221,21 @@ describe("UnboundHandler", () => { }), }), ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("max_tokens") }) it("should not set temperature for openai/o3-mini", async () => { mockCreate.mockClear() - const openaiOptions = { + const openaiHandler = new UnboundHandler({ apiModelId: "openai/o3-mini", unboundApiKey: "test-key", unboundModelId: "openai/o3-mini", - } - const openaiHandler = new UnboundHandler(openaiOptions) + }) await openaiHandler.completePrompt("Test prompt") + expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ model: "o3-mini", @@ -257,24 +247,22 @@ describe("UnboundHandler", () => { }), }), ) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty("temperature") }) }) - describe("getModel", () => { - it("should return model info", () => { - const modelInfo = handler.getModel() - expect(modelInfo.id).toBe(mockOptions.apiModelId) + describe("fetchModel", () => { + it("should return model info", async () => { + const modelInfo = await handler.fetchModel() + expect(modelInfo.id).toBe(mockOptions.unboundModelId) expect(modelInfo.info).toBeDefined() }) - it("should return default model when invalid model provided", () => { - const handlerWithInvalidModel = new UnboundHandler({ - ...mockOptions, - unboundModelId: "invalid/model", - }) - const modelInfo = handlerWithInvalidModel.getModel() - expect(modelInfo.id).toBe("anthropic/claude-3-5-sonnet-20241022") // Default model + it("should return default model when invalid model provided", async () => { + const handlerWithInvalidModel = new UnboundHandler({ ...mockOptions, unboundModelId: "invalid/model" }) + const modelInfo = await handlerWithInvalidModel.fetchModel() + expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet-20250219") expect(modelInfo.info).toBeDefined() }) }) diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts index cdd574ab39..3c1c8e7a16 100644 --- a/src/api/providers/router-provider.ts +++ b/src/api/providers/router-provider.ts @@ -43,7 +43,7 @@ export abstract class RouterProvider extends BaseProvider { this.client = new OpenAI({ baseURL, apiKey }) } - protected async fetchModel() { + public async fetchModel() { this.models = await getModels(this.name) return this.getModel() } 
diff --git a/src/shared/api.ts b/src/shared/api.ts index 3827a5a3b2..b5d38e421b 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1087,7 +1087,7 @@ export const mistralModels = { // Unbound Security // https://www.unboundsecurity.ai/ai-gateway -export const unboundDefaultModelId = "anthropic/claude-3-5-sonnet-20241022" +export const unboundDefaultModelId = "anthropic/claude-3-7-sonnet-20250219" export const unboundDefaultModelInfo: ModelInfo = { maxTokens: 8192, contextWindow: 200_000, From de3e2a525e8340d918faa54873fe9d786f3f2993 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:18:00 -0700 Subject: [PATCH 09/15] Test tweaks --- src/api/providers/__tests__/glama.test.ts | 45 ++++++++++------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/api/providers/__tests__/glama.test.ts b/src/api/providers/__tests__/glama.test.ts index 93edde8748..c7903a8e55 100644 --- a/src/api/providers/__tests__/glama.test.ts +++ b/src/api/providers/__tests__/glama.test.ts @@ -19,31 +19,18 @@ jest.mock("openai", () => { const stream = { [Symbol.asyncIterator]: async function* () { yield { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], + choices: [{ delta: { content: "Test response" }, index: 0 }], usage: null, } yield { - choices: [ - { - delta: {}, - index: 0, - }, - ], - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, + choices: [{ delta: {}, index: 0 }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, } }, } const result = mockCreate(...args) + if (args[0].stream) { mockWithResponse.mockReturnValue( Promise.resolve({ @@ -58,6 +45,7 @@ jest.mock("openai", () => { ) result.withResponse = mockWithResponse } + return result }, }, @@ -72,10 +60,10 @@ describe("GlamaHandler", () => { beforeEach(() => { mockOptions = { - apiModelId: "anthropic/claude-3-7-sonnet", - glamaModelId: "anthropic/claude-3-7-sonnet", glamaApiKey: "test-api-key", + glamaModelId: "anthropic/claude-3-7-sonnet", } + handler = new GlamaHandler(mockOptions) mockCreate.mockClear() mockWithResponse.mockClear() @@ -101,7 +89,7 @@ describe("GlamaHandler", () => { describe("constructor", () => { it("should initialize with provided options", () => { expect(handler).toBeInstanceOf(GlamaHandler) - expect(handler.getModel().id).toBe(mockOptions.apiModelId) + expect(handler.getModel().id).toBe(mockOptions.glamaModelId) }) }) @@ -152,7 +140,7 @@ describe("GlamaHandler", () => { expect(result).toBe("Test response") expect(mockCreate).toHaveBeenCalledWith( expect.objectContaining({ - model: mockOptions.apiModelId, + model: mockOptions.glamaModelId, messages: [{ role: "user", content: "Test prompt" }], temperature: 0, max_tokens: 8192, @@ -196,13 +184,20 @@ describe("GlamaHandler", () => { }) }) - describe("getModel", () => { - it("should return model info", () => { - const modelInfo = handler.getModel() - expect(modelInfo.id).toBe(mockOptions.apiModelId) + describe("fetchModel", () => { + it("should return model info", async () => { + const modelInfo = await handler.fetchModel() + expect(modelInfo.id).toBe(mockOptions.glamaModelId) expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(200_000) }) + + it("should return default model when invalid model provided", async () => { + const handlerWithInvalidModel = new GlamaHandler({ ...mockOptions, glamaModelId: "invalid/model" }) + const modelInfo = await handlerWithInvalidModel.fetchModel() + 
expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet") + expect(modelInfo.info).toBeDefined() + }) }) }) From 02274efb100cd4951b2ecf0372a67d8934784e5d Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:34:12 -0700 Subject: [PATCH 10/15] Revert this --- src/core/webview/ClineProvider.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 8b9f6b533a..b27853e35a 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -335,12 +335,13 @@ export class ClineProvider extends EventEmitter implements async resolveWebviewView(webviewView: vscode.WebviewView | vscode.WebviewPanel) { this.log("Resolving webview view") - this.view = webviewView if (!this.contextProxy.isInitialized) { await this.contextProxy.initialize() } + this.view = webviewView + // Set panel reference according to webview type if ("onDidChangeViewState" in webviewView) { // Tag page type From 1bf7c790443f4107041633f494eaebf5a61ddd2a Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 13:40:52 -0700 Subject: [PATCH 11/15] Sync evals types to extension types --- evals/apps/web/src/app/runs/new/new-run.tsx | 3 +-- evals/packages/types/src/roo-code.ts | 8 -------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/evals/apps/web/src/app/runs/new/new-run.tsx b/evals/apps/web/src/app/runs/new/new-run.tsx index 38759eaa87..71b7422ff3 100644 --- a/evals/apps/web/src/app/runs/new/new-run.tsx +++ b/evals/apps/web/src/app/runs/new/new-run.tsx @@ -94,8 +94,7 @@ export function NewRun() { } const openRouterModelId = openRouterModel.id - const openRouterModelInfo = openRouterModel.modelInfo - values.settings = { ...(values.settings || {}), openRouterModelId, openRouterModelInfo } + values.settings = { ...(values.settings || {}), openRouterModelId } } const { id } = await createRun(values) diff --git a/evals/packages/types/src/roo-code.ts b/evals/packages/types/src/roo-code.ts index 2c467f6a9b..3984d54882 100644 --- a/evals/packages/types/src/roo-code.ts +++ b/evals/packages/types/src/roo-code.ts @@ -304,12 +304,10 @@ export const providerSettingsSchema = z.object({ anthropicUseAuthToken: z.boolean().optional(), // Glama glamaModelId: z.string().optional(), - glamaModelInfo: modelInfoSchema.optional(), glamaApiKey: z.string().optional(), // OpenRouter openRouterApiKey: z.string().optional(), openRouterModelId: z.string().optional(), - openRouterModelInfo: modelInfoSchema.optional(), openRouterBaseUrl: z.string().optional(), openRouterSpecificProvider: z.string().optional(), openRouterUseMiddleOutTransform: z.boolean().optional(), @@ -371,11 +369,9 @@ export const providerSettingsSchema = z.object({ // Unbound unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), - unboundModelInfo: modelInfoSchema.optional(), // Requesty requestyApiKey: z.string().optional(), requestyModelId: z.string().optional(), - requestyModelInfo: modelInfoSchema.optional(), // Claude 3.7 Sonnet Thinking modelMaxTokens: z.number().optional(), // Currently only used by Anthropic hybrid thinking models. modelMaxThinkingTokens: z.number().optional(), // Currently only used by Anthropic hybrid thinking models. 
@@ -401,12 +397,10 @@ const providerSettingsRecord: ProviderSettingsRecord = { anthropicUseAuthToken: undefined, // Glama glamaModelId: undefined, - glamaModelInfo: undefined, glamaApiKey: undefined, // OpenRouter openRouterApiKey: undefined, openRouterModelId: undefined, - openRouterModelInfo: undefined, openRouterBaseUrl: undefined, openRouterSpecificProvider: undefined, openRouterUseMiddleOutTransform: undefined, @@ -460,11 +454,9 @@ const providerSettingsRecord: ProviderSettingsRecord = { // Unbound unboundApiKey: undefined, unboundModelId: undefined, - unboundModelInfo: undefined, // Requesty requestyApiKey: undefined, requestyModelId: undefined, - requestyModelInfo: undefined, // Claude 3.7 Sonnet Thinking modelMaxTokens: undefined, modelMaxThinkingTokens: undefined, From a2a1c78863c93f8589f27a7ce6de38c46196ba18 Mon Sep 17 00:00:00 2001 From: Chris Estreich Date: Fri, 25 Apr 2025 13:45:05 -0700 Subject: [PATCH 12/15] Update webview-ui/src/components/ui/hooks/useSelectedModel.ts Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- webview-ui/src/components/ui/hooks/useSelectedModel.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 89f9e3a5ca..11c018a43f 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -118,6 +118,6 @@ function getSelectedModelInfo({ supportsImages: false, // VSCode LM API currently doesn't support images. } default: - return anthropicModels[id as keyof typeof anthropicModels] ?? anthropicDefaultModelId + return anthropicModels[id as keyof typeof anthropicModels] ?? anthropicModels[anthropicDefaultModelId] } } From 341e83694d4f03bcff7f1528f4586493ec4df168 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 14:09:35 -0700 Subject: [PATCH 13/15] Pull caching logic into transforms --- src/api/providers/glama.ts | 33 ++++----------------------------- src/api/providers/unbound.ts | 29 +++-------------------------- src/api/transform/caching.ts | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 55 deletions(-) create mode 100644 src/api/transform/caching.ts diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts index 7cc366b432..72109f6672 100644 --- a/src/api/providers/glama.ts +++ b/src/api/providers/glama.ts @@ -3,8 +3,9 @@ import axios from "axios" import OpenAI from "openai" import { ApiHandlerOptions, glamaDefaultModelId, glamaDefaultModelInfo } from "../../shared/api" -import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { addCacheControlDirectives } from "../transform/caching" import { SingleCompletionHandler } from "../index" import { RouterProvider } from "./router-provider" @@ -36,31 +37,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHand ] if (modelId.startsWith("anthropic/claude-3")) { - openAiMessages[0] = { - role: "system", - // @ts-ignore-next-line - content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }], - } - - const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) - - lastTwoUserMessages.forEach((msg) => { - if (typeof msg.content === "string") { - msg.content = [{ type: "text", text: msg.content }] - } - - if 
(Array.isArray(msg.content)) { - let lastTextPart = msg.content.filter((part) => part.type === "text").pop() - - if (!lastTextPart) { - lastTextPart = { type: "text", text: "..." } - msg.content.push(lastTextPart) - } - - // @ts-ignore-next-line - lastTextPart["cache_control"] = { type: "ephemeral" } - } - }) + addCacheControlDirectives(systemPrompt, openAiMessages) } // Required by Anthropic; other providers default to max tokens allowed. @@ -82,9 +59,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHand } const { data: completion, response } = await this.client.chat.completions - .create(requestOptions, { - headers: DEFAULT_HEADERS, - }) + .create(requestOptions, { headers: DEFAULT_HEADERS }) .withResponse() const completionRequestId = response.headers.get("x-completion-request-id") diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index 1de6667ff7..27c10313a6 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -2,8 +2,9 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import { ApiHandlerOptions, unboundDefaultModelId, unboundDefaultModelInfo } from "../../shared/api" -import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { convertToOpenAiMessages } from "../transform/openai-format" +import { addCacheControlDirectives } from "../transform/caching" import { SingleCompletionHandler } from "../index" import { RouterProvider } from "./router-provider" @@ -38,31 +39,7 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa ] if (modelId.startsWith("anthropic/claude-3")) { - openAiMessages[0] = { - role: "system", - // @ts-ignore-next-line - content: [{ type: "text", text: systemPrompt, cache_control: { type: "ephemeral" } }], - } - - const lastTwoUserMessages = openAiMessages.filter((msg) => msg.role === "user").slice(-2) - - lastTwoUserMessages.forEach((msg) => { - if (typeof msg.content === "string") { - msg.content = [{ type: "text", text: msg.content }] - } - - if (Array.isArray(msg.content)) { - let lastTextPart = msg.content.filter((part) => part.type === "text").pop() - - if (!lastTextPart) { - lastTextPart = { type: "text", text: "..." } - msg.content.push(lastTextPart) - } - - // @ts-ignore-next-line - lastTextPart["cache_control"] = { type: "ephemeral" } - } - }) + addCacheControlDirectives(systemPrompt, openAiMessages) } // Required by Anthropic; other providers default to max tokens allowed. diff --git a/src/api/transform/caching.ts b/src/api/transform/caching.ts new file mode 100644 index 0000000000..0a8ae6bf45 --- /dev/null +++ b/src/api/transform/caching.ts @@ -0,0 +1,36 @@ +import OpenAI from "openai" + +export const addCacheControlDirectives = (systemPrompt: string, messages: OpenAI.Chat.ChatCompletionMessageParam[]) => { + messages[0] = { + role: "system", + content: [ + { + type: "text", + text: systemPrompt, + // @ts-ignore-next-line + cache_control: { type: "ephemeral" }, + }, + ], + } + + messages + .filter((msg) => msg.role === "user") + .slice(-2) + .forEach((msg) => { + if (typeof msg.content === "string") { + msg.content = [{ type: "text", text: msg.content }] + } + + if (Array.isArray(msg.content)) { + let lastTextPart = msg.content.filter((part) => part.type === "text").pop() + + if (!lastTextPart) { + lastTextPart = { type: "text", text: "..." 
} + msg.content.push(lastTextPart) + } + + // @ts-ignore-next-line + lastTextPart["cache_control"] = { type: "ephemeral" } + } + }) +} From 1170466f0db801224d6b5b71b314792fe5aacaca Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 14:41:09 -0700 Subject: [PATCH 14/15] Fix flakes --- src/api/providers/__tests__/openrouter.test.ts | 2 +- .../checkpoints/__tests__/ShadowCheckpointService.test.ts | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/api/providers/__tests__/openrouter.test.ts b/src/api/providers/__tests__/openrouter.test.ts index b53a276f39..8cf500930f 100644 --- a/src/api/providers/__tests__/openrouter.test.ts +++ b/src/api/providers/__tests__/openrouter.test.ts @@ -32,7 +32,7 @@ describe("OpenRouterHandler", () => { }) }) - describe("getModel", () => { + describe("fetchModel", () => { it("returns correct model info when options are provided", async () => { const handler = new OpenRouterHandler(mockOptions) const result = await handler.fetchModel() diff --git a/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts b/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts index 6e42cfae07..84589c5fd2 100644 --- a/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts +++ b/src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts @@ -12,6 +12,8 @@ import * as fileSearch from "../../../services/search/file-search" import { RepoPerTaskCheckpointService } from "../RepoPerTaskCheckpointService" +jest.setTimeout(10_000) + const tmpDir = path.join(os.tmpdir(), "CheckpointService") const initWorkspaceRepo = async ({ From cd58aade733d1c9a848ddb9f745e2f463ae34960 Mon Sep 17 00:00:00 2001 From: cte Date: Fri, 25 Apr 2025 15:08:55 -0700 Subject: [PATCH 15/15] More cleanup --- src/api/providers/fetchers/cache.ts | 16 ++-- src/api/providers/openrouter.ts | 6 +- src/api/providers/requesty.ts | 4 +- src/api/providers/router-provider.ts | 4 +- src/shared/ExtensionMessage.ts | 7 +- src/shared/api.ts | 10 ++ .../settings/__tests__/SettingsView.test.tsx | 10 +- .../components/ui/hooks/useRouterModels.ts | 4 +- .../components/ui/hooks/useSelectedModel.ts | 3 +- .../src/components/welcome/WelcomeView.tsx | 2 +- webview-ui/src/utils/validate.ts | 95 +++++-------------- 11 files changed, 56 insertions(+), 105 deletions(-) diff --git a/src/api/providers/fetchers/cache.ts b/src/api/providers/fetchers/cache.ts index 890ee91a1a..ab6dcce021 100644 --- a/src/api/providers/fetchers/cache.ts +++ b/src/api/providers/fetchers/cache.ts @@ -5,21 +5,15 @@ import NodeCache from "node-cache" import { ContextProxy } from "../../../core/config/ContextProxy" import { getCacheDirectoryPath } from "../../../shared/storagePathManager" +import { RouterName, ModelRecord } from "../../../shared/api" import { fileExistsAtPath } from "../../../utils/fs" -import type { ModelInfo } from "../../../schemas" + import { getOpenRouterModels } from "./openrouter" import { getRequestyModels } from "./requesty" import { getGlamaModels } from "./glama" import { getUnboundModels } from "./unbound" -export type RouterName = "openrouter" | "requesty" | "glama" | "unbound" - -export type ModelRecord = Record - -const memoryCache = new NodeCache({ - stdTTL: 5 * 60, - checkperiod: 5 * 60, -}) +const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { const filename = `${router}_models.json` @@ -48,6 +42,7 @@ export const getModels = async (router: RouterName): Promise => { let models = 
memoryCache.get(router) if (models) { + // console.log(`[getModels] NodeCache hit for ${router} -> ${Object.keys(models).length}`) return models } @@ -67,10 +62,12 @@ export const getModels = async (router: RouterName): Promise => { } if (Object.keys(models).length > 0) { + // console.log(`[getModels] API fetch for ${router} -> ${Object.keys(models).length}`) memoryCache.set(router, models) try { await writeModels(router, models) + // console.log(`[getModels] wrote ${router} models to file cache`) } catch (error) {} return models @@ -78,6 +75,7 @@ export const getModels = async (router: RouterName): Promise => { try { models = await readModels(router) + // console.log(`[getModels] read ${router} models from file cache`) } catch (error) {} return models ?? {} diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index b1616cf47b..ed19c8496e 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -4,6 +4,7 @@ import OpenAI from "openai" import { ApiHandlerOptions, + ModelRecord, openRouterDefaultModelId, openRouterDefaultModelInfo, PROMPT_CACHING_MODELS, @@ -16,7 +17,7 @@ import { convertToR1Format } from "../transform/r1-format" import { getModelParams, SingleCompletionHandler } from "../index" import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants" import { BaseProvider } from "./base-provider" -import { ModelRecord, getModels } from "./fetchers/cache" +import { getModels } from "./fetchers/cache" const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]" @@ -130,8 +131,6 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } // https://openrouter.ai/docs/transforms - let fullResponseText = "" - const completionParams: OpenRouterChatCompletionParams = { model: modelId, max_tokens: maxTokens, @@ -170,7 +169,6 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } if (delta?.content) { - fullResponseText += delta.content yield { type: "text", text: delta.content } } diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index 118aaf1caf..9fe976bb51 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -1,11 +1,11 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo } from "../../shared/api" +import { ModelInfo, ModelRecord, requestyDefaultModelId, requestyDefaultModelInfo } from "../../shared/api" import { calculateApiCostOpenAI } from "../../utils/cost" import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" import { OpenAiHandler, OpenAiHandlerOptions } from "./openai" -import { ModelRecord, getModels } from "./fetchers/cache" +import { getModels } from "./fetchers/cache" // Requesty usage includes an extra field for Anthropic use cases. // Safely cast the prompt token details section to the appropriate structure. 
diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts index 3c1c8e7a16..5b680b1b1d 100644 --- a/src/api/providers/router-provider.ts +++ b/src/api/providers/router-provider.ts @@ -1,8 +1,8 @@ import OpenAI from "openai" -import { ApiHandlerOptions, ModelInfo } from "../../shared/api" +import { ApiHandlerOptions, RouterName, ModelRecord, ModelInfo } from "../../shared/api" import { BaseProvider } from "./base-provider" -import { RouterName, ModelRecord, getModels } from "./fetchers/cache" +import { getModels } from "./fetchers/cache" type RouterProviderOptions = { name: RouterName diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 11366cfede..00b2cc5194 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -1,5 +1,6 @@ +import { GitCommit } from "../utils/git" + import { - ModelInfo, GlobalSettings, ApiConfigMeta, ProviderSettings as ApiConfiguration, @@ -13,8 +14,8 @@ import { ClineMessage, } from "../schemas" import { McpServer } from "./mcp" -import { GitCommit } from "../utils/git" import { Mode } from "./modes" +import { RouterModels } from "./api" export type { ApiConfigMeta, ToolProgressStatus } @@ -84,7 +85,7 @@ export interface ExtensionMessage { path?: string }> partialMessage?: ClineMessage - routerModels?: Record<"openrouter" | "requesty" | "glama" | "unbound", Record> + routerModels?: RouterModels openAiModels?: string[] ollamaModels?: string[] lmStudioModels?: string[] diff --git a/src/shared/api.ts b/src/shared/api.ts index b5d38e421b..2559232c11 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1446,3 +1446,13 @@ export const COMPUTER_USE_MODELS = new Set([ "anthropic/claude-3.7-sonnet:beta", "anthropic/claude-3.7-sonnet:thinking", ]) + +const routerNames = ["openrouter", "requesty", "glama", "unbound"] as const + +export type RouterName = (typeof routerNames)[number] + +export const isRouterName = (value: string): value is RouterName => routerNames.includes(value as RouterName) + +export type ModelRecord = Record + +export type RouterModels = Record diff --git a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx index 9b2fc37d25..81f3dea1fd 100644 --- a/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx +++ b/webview-ui/src/components/settings/__tests__/SettingsView.test.tsx @@ -9,11 +9,7 @@ import { ExtensionStateContextProvider } from "@/context/ExtensionStateContext" import SettingsView from "../SettingsView" // Mock vscode API -jest.mock("@src/utils/vscode", () => ({ - vscode: { - postMessage: jest.fn(), - }, -})) +jest.mock("@src/utils/vscode", () => ({ vscode: { postMessage: jest.fn() } })) // Mock all lucide-react icons with a proxy to handle any icon requested jest.mock("lucide-react", () => { @@ -79,10 +75,10 @@ jest.mock("@vscode/webview-ui-toolkit/react", () => ({ /> ), VSCodeLink: ({ children, href }: any) => {children}, - VSCodeRadio: ({ children, value, checked, onChange }: any) => ( + VSCodeRadio: ({ value, checked, onChange }: any) => ( ), - VSCodeRadioGroup: ({ children, value, onChange }: any) =>
<div onChange={onChange}>{children}</div>,
+	VSCodeRadioGroup: ({ children, onChange }: any) => <div onChange={onChange}>{children}</div>
, })) // Mock Slider component diff --git a/webview-ui/src/components/ui/hooks/useRouterModels.ts b/webview-ui/src/components/ui/hooks/useRouterModels.ts index 8140b8e533..56b2954db0 100644 --- a/webview-ui/src/components/ui/hooks/useRouterModels.ts +++ b/webview-ui/src/components/ui/hooks/useRouterModels.ts @@ -1,11 +1,9 @@ -import { ModelInfo } from "@roo/shared/api" +import { RouterModels } from "@roo/shared/api" import { vscode } from "@src/utils/vscode" import { ExtensionMessage } from "@roo/shared/ExtensionMessage" import { useQuery } from "@tanstack/react-query" -export type RouterModels = Record<"openrouter" | "requesty" | "glama" | "unbound", Record> - const getRouterModels = async () => new Promise((resolve, reject) => { const cleanup = () => { diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 11c018a43f..6b5b01b918 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -1,5 +1,6 @@ import { ApiConfiguration, + RouterModels, ModelInfo, anthropicDefaultModelId, anthropicModels, @@ -26,7 +27,7 @@ import { unboundDefaultModelId, } from "@roo/shared/api" -import { type RouterModels, useRouterModels } from "./useRouterModels" +import { useRouterModels } from "./useRouterModels" export const useSelectedModel = (apiConfiguration?: ApiConfiguration) => { const { data: routerModels, isLoading, isError } = useRouterModels() diff --git a/webview-ui/src/components/welcome/WelcomeView.tsx b/webview-ui/src/components/welcome/WelcomeView.tsx index 07871496a4..dd517f7b06 100644 --- a/webview-ui/src/components/welcome/WelcomeView.tsx +++ b/webview-ui/src/components/welcome/WelcomeView.tsx @@ -17,7 +17,7 @@ const WelcomeView = () => { const [errorMessage, setErrorMessage] = useState(undefined) const handleSubmit = useCallback(() => { - const error = validateApiConfiguration(apiConfiguration) + const error = apiConfiguration ? validateApiConfiguration(apiConfiguration) : undefined if (error) { setErrorMessage(error) diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 55fa77a268..5c072edc0d 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -1,11 +1,8 @@ -import { ApiConfiguration, ModelInfo } from "@roo/shared/api" import i18next from "i18next" -export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): string | undefined { - if (!apiConfiguration) { - return undefined - } +import { ApiConfiguration, isRouterName, RouterModels } from "@roo/shared/api" +export function validateApiConfiguration(apiConfiguration: ApiConfiguration): string | undefined { switch (apiConfiguration.apiProvider) { case "openrouter": if (!apiConfiguration.openRouterApiKey) { @@ -113,89 +110,41 @@ export function validateBedrockArn(arn: string, region?: string) { } // ARN is valid and region matches (or no region was provided to check against) - return { - isValid: true, - arnRegion, - errorMessage: undefined, - } + return { isValid: true, arnRegion, errorMessage: undefined } } -export function validateModelId( - apiConfiguration?: ApiConfiguration, - routerModels?: Record<"openrouter" | "glama" | "unbound" | "requesty", Record>, -): string | undefined { - if (!apiConfiguration) { +export function validateModelId(apiConfiguration: ApiConfiguration, routerModels?: RouterModels): string | undefined { + const provider = apiConfiguration.apiProvider ?? 
"" + + if (!isRouterName(provider)) { return undefined } - switch (apiConfiguration.apiProvider) { - case "openrouter": - const modelId = apiConfiguration.openRouterModelId - - if (!modelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - routerModels?.openrouter && - Object.keys(routerModels.openrouter).length > 1 && - !Object.keys(routerModels.openrouter).includes(modelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId }) - } + let modelId: string | undefined + switch (provider) { + case "openrouter": + modelId = apiConfiguration.openRouterModelId break - case "glama": - const glamaModelId = apiConfiguration.glamaModelId - - if (!glamaModelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - routerModels?.glama && - Object.keys(routerModels.glama).length > 1 && - !Object.keys(routerModels.glama).includes(glamaModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: glamaModelId }) - } - + modelId = apiConfiguration.glamaModelId break - case "unbound": - const unboundModelId = apiConfiguration.unboundModelId - - if (!unboundModelId) { - return i18next.t("settings:validation.modelId") - } - - if ( - routerModels?.unbound && - Object.keys(routerModels.unbound).length > 1 && - !Object.keys(routerModels.unbound).includes(unboundModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: unboundModelId }) - } - + modelId = apiConfiguration.unboundModelId break - case "requesty": - const requestyModelId = apiConfiguration.requestyModelId + modelId = apiConfiguration.requestyModelId + break + } - if (!requestyModelId) { - return i18next.t("settings:validation.modelId") - } + if (!modelId) { + return i18next.t("settings:validation.modelId") + } - if ( - routerModels?.requesty && - Object.keys(routerModels.requesty).length > 1 && - !Object.keys(routerModels.requesty).includes(requestyModelId) - ) { - return i18next.t("settings:validation.modelAvailability", { modelId: requestyModelId }) - } + const models = routerModels?.[provider] - break + if (models && Object.keys(models).length > 1 && !Object.keys(models).includes(modelId)) { + return i18next.t("settings:validation.modelAvailability", { modelId }) } return undefined