diff --git a/src/api/providers/fetchers/__tests__/litellm.test.ts b/src/api/providers/fetchers/__tests__/litellm.test.ts new file mode 100644 index 0000000000..4f59c3ab31 --- /dev/null +++ b/src/api/providers/fetchers/__tests__/litellm.test.ts @@ -0,0 +1,227 @@ +import axios from "axios" +import { getLiteLLMModels } from "../litellm" +import { COMPUTER_USE_MODELS } from "../../../../shared/api" + +// Mock axios +jest.mock("axios") +const mockedAxios = axios as jest.Mocked + +const DUMMY_INVALID_KEY = "invalid-key-for-testing" + +describe("getLiteLLMModels", () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it("successfully fetches and formats LiteLLM models", async () => { + const mockResponse = { + data: { + data: [ + { + model_name: "claude-3-5-sonnet", + model_info: { + max_tokens: 4096, + max_input_tokens: 200000, + supports_vision: true, + supports_prompt_caching: false, + input_cost_per_token: 0.000003, + output_cost_per_token: 0.000015, + }, + litellm_params: { + model: "anthropic/claude-3.5-sonnet", + }, + }, + { + model_name: "gpt-4-turbo", + model_info: { + max_tokens: 8192, + max_input_tokens: 128000, + supports_vision: false, + supports_prompt_caching: false, + input_cost_per_token: 0.00001, + output_cost_per_token: 0.00003, + }, + litellm_params: { + model: "openai/gpt-4-turbo", + }, + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const result = await getLiteLLMModels("test-api-key", "http://localhost:4000") + + expect(mockedAxios.get).toHaveBeenCalledWith("http://localhost:4000/v1/model/info", { + headers: { + Authorization: "Bearer test-api-key", + "Content-Type": "application/json", + }, + timeout: 5000, + }) + + expect(result).toEqual({ + "claude-3-5-sonnet": { + maxTokens: 4096, + contextWindow: 200000, + supportsImages: true, + supportsComputerUse: true, + supportsPromptCache: false, + inputPrice: 3, + outputPrice: 15, + description: "claude-3-5-sonnet via LiteLLM proxy", + }, + "gpt-4-turbo": { + maxTokens: 8192, + contextWindow: 128000, + supportsImages: false, + supportsComputerUse: false, + supportsPromptCache: false, + inputPrice: 10, + outputPrice: 30, + description: "gpt-4-turbo via LiteLLM proxy", + }, + }) + }) + + it("makes request without authorization header when no API key provided", async () => { + const mockResponse = { + data: { + data: [], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + await getLiteLLMModels("", "http://localhost:4000") + + expect(mockedAxios.get).toHaveBeenCalledWith("http://localhost:4000/v1/model/info", { + headers: { + "Content-Type": "application/json", + }, + timeout: 5000, + }) + }) + + it("handles computer use models correctly", async () => { + const computerUseModel = Array.from(COMPUTER_USE_MODELS)[0] + const mockResponse = { + data: { + data: [ + { + model_name: "test-computer-model", + model_info: { + max_tokens: 4096, + max_input_tokens: 200000, + supports_vision: true, + }, + litellm_params: { + model: `anthropic/${computerUseModel}`, + }, + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const result = await getLiteLLMModels("test-api-key", "http://localhost:4000") + + expect(result["test-computer-model"]).toEqual({ + maxTokens: 4096, + contextWindow: 200000, + supportsImages: true, + supportsComputerUse: true, + supportsPromptCache: false, + inputPrice: undefined, + outputPrice: undefined, + description: "test-computer-model via LiteLLM proxy", + }) + }) + + it("throws error for unexpected response format", async () => { + const mockResponse = { + data: { + // Missing 'data' field + models: [], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + await expect(getLiteLLMModels("test-api-key", "http://localhost:4000")).rejects.toThrow( + "Failed to fetch LiteLLM models: Unexpected response format.", + ) + }) + + it("throws detailed error for HTTP error responses", async () => { + const axiosError = { + response: { + status: 401, + statusText: "Unauthorized", + }, + isAxiosError: true, + } + + mockedAxios.isAxiosError.mockReturnValue(true) + mockedAxios.get.mockRejectedValue(axiosError) + + await expect(getLiteLLMModels(DUMMY_INVALID_KEY, "http://localhost:4000")).rejects.toThrow( + "Failed to fetch LiteLLM models: 401 Unauthorized. Check base URL and API key.", + ) + }) + + it("throws network error for request failures", async () => { + const axiosError = { + request: {}, + isAxiosError: true, + } + + mockedAxios.isAxiosError.mockReturnValue(true) + mockedAxios.get.mockRejectedValue(axiosError) + + await expect(getLiteLLMModels("test-api-key", "http://invalid-url")).rejects.toThrow( + "Failed to fetch LiteLLM models: No response from server. Check LiteLLM server status and base URL.", + ) + }) + + it("throws generic error for other failures", async () => { + const genericError = new Error("Network timeout") + + mockedAxios.isAxiosError.mockReturnValue(false) + mockedAxios.get.mockRejectedValue(genericError) + + await expect(getLiteLLMModels("test-api-key", "http://localhost:4000")).rejects.toThrow( + "Failed to fetch LiteLLM models: Network timeout", + ) + }) + + it("handles timeout parameter correctly", async () => { + const mockResponse = { data: { data: [] } } + mockedAxios.get.mockResolvedValue(mockResponse) + + await getLiteLLMModels("test-api-key", "http://localhost:4000") + + expect(mockedAxios.get).toHaveBeenCalledWith( + "http://localhost:4000/v1/model/info", + expect.objectContaining({ + timeout: 5000, + }), + ) + }) + + it("returns empty object when data array is empty", async () => { + const mockResponse = { + data: { + data: [], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const result = await getLiteLLMModels("test-api-key", "http://localhost:4000") + + expect(result).toEqual({}) + }) +}) diff --git a/src/api/providers/fetchers/__tests__/modelCache.test.ts b/src/api/providers/fetchers/__tests__/modelCache.test.ts new file mode 100644 index 0000000000..abc477a8a5 --- /dev/null +++ b/src/api/providers/fetchers/__tests__/modelCache.test.ts @@ -0,0 +1,158 @@ +import { getModels } from "../modelCache" +import { getLiteLLMModels } from "../litellm" +import { getOpenRouterModels } from "../openrouter" +import { getRequestyModels } from "../requesty" +import { getGlamaModels } from "../glama" +import { getUnboundModels } from "../unbound" + +// Mock NodeCache to avoid cache interference +jest.mock("node-cache", () => { + return jest.fn().mockImplementation(() => ({ + get: jest.fn().mockReturnValue(undefined), // Always return cache miss + set: jest.fn(), + del: jest.fn(), + })) +}) + +// Mock fs/promises to avoid file system operations +jest.mock("fs/promises", () => ({ + writeFile: jest.fn().mockResolvedValue(undefined), + readFile: jest.fn().mockResolvedValue("{}"), + mkdir: jest.fn().mockResolvedValue(undefined), +})) + +// Mock all the model fetchers +jest.mock("../litellm") +jest.mock("../openrouter") +jest.mock("../requesty") +jest.mock("../glama") +jest.mock("../unbound") + +const mockGetLiteLLMModels = getLiteLLMModels as jest.MockedFunction +const mockGetOpenRouterModels = getOpenRouterModels as jest.MockedFunction +const mockGetRequestyModels = getRequestyModels as jest.MockedFunction +const mockGetGlamaModels = getGlamaModels as jest.MockedFunction +const mockGetUnboundModels = getUnboundModels as jest.MockedFunction + +const DUMMY_REQUESTY_KEY = "requesty-key-for-testing" +const DUMMY_UNBOUND_KEY = "unbound-key-for-testing" + +describe("getModels with new GetModelsOptions", () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it("calls getLiteLLMModels with correct parameters", async () => { + const mockModels = { + "claude-3-sonnet": { + maxTokens: 4096, + contextWindow: 200000, + supportsPromptCache: false, + description: "Claude 3 Sonnet via LiteLLM", + }, + } + mockGetLiteLLMModels.mockResolvedValue(mockModels) + + const result = await getModels({ + provider: "litellm", + apiKey: "test-api-key", + baseUrl: "http://localhost:4000", + }) + + expect(mockGetLiteLLMModels).toHaveBeenCalledWith("test-api-key", "http://localhost:4000") + expect(result).toEqual(mockModels) + }) + + it("calls getOpenRouterModels for openrouter provider", async () => { + const mockModels = { + "openrouter/model": { + maxTokens: 8192, + contextWindow: 128000, + supportsPromptCache: false, + description: "OpenRouter model", + }, + } + mockGetOpenRouterModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "openrouter" }) + + expect(mockGetOpenRouterModels).toHaveBeenCalled() + expect(result).toEqual(mockModels) + }) + + it("calls getRequestyModels with optional API key", async () => { + const mockModels = { + "requesty/model": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Requesty model", + }, + } + mockGetRequestyModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "requesty", apiKey: DUMMY_REQUESTY_KEY }) + + expect(mockGetRequestyModels).toHaveBeenCalledWith(DUMMY_REQUESTY_KEY) + expect(result).toEqual(mockModels) + }) + + it("calls getGlamaModels for glama provider", async () => { + const mockModels = { + "glama/model": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Glama model", + }, + } + mockGetGlamaModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "glama" }) + + expect(mockGetGlamaModels).toHaveBeenCalled() + expect(result).toEqual(mockModels) + }) + + it("calls getUnboundModels with optional API key", async () => { + const mockModels = { + "unbound/model": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Unbound model", + }, + } + mockGetUnboundModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "unbound", apiKey: DUMMY_UNBOUND_KEY }) + + expect(mockGetUnboundModels).toHaveBeenCalledWith(DUMMY_UNBOUND_KEY) + expect(result).toEqual(mockModels) + }) + + it("handles errors and re-throws them", async () => { + const expectedError = new Error("LiteLLM connection failed") + mockGetLiteLLMModels.mockRejectedValue(expectedError) + + await expect( + getModels({ + provider: "litellm", + apiKey: "test-api-key", + baseUrl: "http://localhost:4000", + }), + ).rejects.toThrow("LiteLLM connection failed") + }) + + it("validates exhaustive provider checking with unknown provider", async () => { + // This test ensures TypeScript catches unknown providers at compile time + // In practice, the discriminated union should prevent this at compile time + const unknownProvider = "unknown" as any + + await expect( + getModels({ + provider: unknownProvider, + }), + ).rejects.toThrow("Unknown provider: unknown") + }) +}) diff --git a/src/api/providers/fetchers/litellm.ts b/src/api/providers/fetchers/litellm.ts index ac143b8acb..713237a627 100644 --- a/src/api/providers/fetchers/litellm.ts +++ b/src/api/providers/fetchers/litellm.ts @@ -7,6 +7,7 @@ import { COMPUTER_USE_MODELS, ModelRecord } from "../../../shared/api" * @param apiKey The API key for the LiteLLM server * @param baseUrl The base URL of the LiteLLM server * @returns A promise that resolves to a record of model IDs to model info + * @throws Will throw an error if the request fails or the response is not as expected. */ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise { try { @@ -17,8 +18,8 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise if (apiKey) { headers["Authorization"] = `Bearer ${apiKey}` } - - const response = await axios.get(`${baseUrl}/v1/model/info`, { headers }) + // Added timeout to prevent indefinite hanging + const response = await axios.get(`${baseUrl}/v1/model/info`, { headers, timeout: 5000 }) const models: ModelRecord = {} const computerModels = Array.from(COMPUTER_USE_MODELS) @@ -48,11 +49,25 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise description: `${modelName} via LiteLLM proxy`, } } + } else { + // If response.data.data is not in the expected format, consider it an error. + console.error("Error fetching LiteLLM models: Unexpected response format", response.data) + throw new Error("Failed to fetch LiteLLM models: Unexpected response format.") } return models - } catch (error) { - console.error("Error fetching LiteLLM models:", error) - return {} + } catch (error: any) { + console.error("Error fetching LiteLLM models:", error.message ? error.message : error) + if (axios.isAxiosError(error) && error.response) { + throw new Error( + `Failed to fetch LiteLLM models: ${error.response.status} ${error.response.statusText}. Check base URL and API key.`, + ) + } else if (axios.isAxiosError(error) && error.request) { + throw new Error( + "Failed to fetch LiteLLM models: No response from server. Check LiteLLM server status and base URL.", + ) + } else { + throw new Error(`Failed to fetch LiteLLM models: ${error.message || "An unknown error occurred."}`) + } } } diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 9517bda8ae..12d636bc46 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -13,7 +13,7 @@ import { getRequestyModels } from "./requesty" import { getGlamaModels } from "./glama" import { getUnboundModels } from "./unbound" import { getLiteLLMModels } from "./litellm" - +import { GetModelsOptions } from "../../../shared/api" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { @@ -41,64 +41,59 @@ async function readModels(router: RouterName): Promise * @param baseUrl - Optional base URL for the provider (currently used only for LiteLLM). * @returns The models from the cache or the fetched models. */ -export const getModels = async ( - router: RouterName, - apiKey: string | undefined = undefined, - baseUrl: string | undefined = undefined, -): Promise => { - let models = memoryCache.get(router) - +export const getModels = async (options: GetModelsOptions): Promise => { + const { provider } = options + let models = memoryCache.get(provider) if (models) { - // console.log(`[getModels] NodeCache hit for ${router} -> ${Object.keys(models).length}`) return models } - switch (router) { - case "openrouter": - models = await getOpenRouterModels() - break - case "requesty": - // Requesty models endpoint requires an API key for per-user custom policies - models = await getRequestyModels(apiKey) - break - case "glama": - models = await getGlamaModels() - break - case "unbound": - // Unbound models endpoint requires an API key to fetch application specific models - models = await getUnboundModels(apiKey) - break - case "litellm": - if (apiKey && baseUrl) { - models = await getLiteLLMModels(apiKey, baseUrl) - } else { - models = {} + try { + switch (provider) { + case "openrouter": + models = await getOpenRouterModels() + break + case "requesty": + // Requesty models endpoint requires an API key for per-user custom policies + models = await getRequestyModels(options.apiKey) + break + case "glama": + models = await getGlamaModels() + break + case "unbound": + // Unbound models endpoint requires an API key to fetch application specific models + models = await getUnboundModels(options.apiKey) + break + case "litellm": + // Type safety ensures apiKey and baseUrl are always provided for litellm + models = await getLiteLLMModels(options.apiKey, options.baseUrl) + break + default: { + // Ensures router is exhaustively checked if RouterName is a strict union + const exhaustiveCheck: never = provider + throw new Error(`Unknown provider: ${exhaustiveCheck}`) } - break - } + } - if (Object.keys(models).length > 0) { - // console.log(`[getModels] API fetch for ${router} -> ${Object.keys(models).length}`) - memoryCache.set(router, models) + // Cache the fetched models (even if empty, to signify a successful fetch with no models) + memoryCache.set(provider, models) + await writeModels(provider, models).catch((err) => + console.error(`[getModels] Error writing ${provider} models to file cache:`, err), + ) try { - await writeModels(router, models) - // console.log(`[getModels] wrote ${router} models to file cache`) + models = await readModels(provider) + // console.log(`[getModels] read ${router} models from file cache`) } catch (error) { - console.error(`[getModels] error writing ${router} models to file cache`, error) + console.error(`[getModels] error reading ${provider} models from file cache`, error) } - - return models - } - - try { - models = await readModels(router) - // console.log(`[getModels] read ${router} models from file cache`) + return models || {} } catch (error) { - console.error(`[getModels] error reading ${router} models from file cache`, error) - } + // Log the error and re-throw it so the caller can handle it (e.g., show a UI message). + console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error) - return models ?? {} + throw error // Re-throw the original error to be handled by the caller. + } } /** diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index c08faeb117..5d1682f821 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -162,7 +162,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH public async fetchModel() { const [models, endpoints] = await Promise.all([ - getModels("openrouter"), + getModels({ provider: "openrouter" }), getModelEndpoints({ router: "openrouter", modelId: this.options.openRouterModelId, diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index fe8bba7e6e..c2e0a12bdd 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -45,7 +45,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan } public async fetchModel() { - this.models = await getModels("requesty") + this.models = await getModels({ provider: "requesty" }) return this.getModel() } diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts index a0decdcab4..30093be9b8 100644 --- a/src/api/providers/router-provider.ts +++ b/src/api/providers/router-provider.ts @@ -44,7 +44,7 @@ export abstract class RouterProvider extends BaseProvider { } public async fetchModel() { - this.models = await getModels(this.name, this.client.apiKey, this.client.baseURL) + this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL }) return this.getModel() } diff --git a/src/core/webview/__tests__/ClineProvider.test.ts b/src/core/webview/__tests__/ClineProvider.test.ts index 0c7b309615..5af4a476b7 100644 --- a/src/core/webview/__tests__/ClineProvider.test.ts +++ b/src/core/webview/__tests__/ClineProvider.test.ts @@ -2128,3 +2128,261 @@ describe("getTelemetryProperties", () => { expect(properties).toHaveProperty("modelId", "claude-3-7-sonnet-20250219") }) }) + +// Mock getModels for router model tests +jest.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: jest.fn(), + flushModels: jest.fn(), +})) + +describe("ClineProvider - Router Models", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: jest.Mock + + beforeEach(() => { + jest.clearAllMocks() + + const globalState: Record = {} + const secrets: Record = {} + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: jest.fn().mockImplementation((key: string) => globalState[key]), + update: jest + .fn() + .mockImplementation((key: string, value: string | undefined) => (globalState[key] = value)), + keys: jest.fn().mockImplementation(() => Object.keys(globalState)), + }, + secrets: { + get: jest.fn().mockImplementation((key: string) => secrets[key]), + store: jest.fn().mockImplementation((key: string, value: string | undefined) => (secrets[key] = value)), + delete: jest.fn().mockImplementation((key: string) => delete secrets[key]), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: jest.fn(), + clear: jest.fn(), + dispose: jest.fn(), + } as unknown as vscode.OutputChannel + + mockPostMessage = jest.fn() + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: jest.fn(), + asWebviewUri: jest.fn(), + }, + visible: true, + onDidDispose: jest.fn().mockImplementation((callback) => { + callback() + return { dispose: jest.fn() } + }), + onDidChangeVisibility: jest.fn().mockImplementation(() => ({ dispose: jest.fn() })), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + }) + + test("handles requestRouterModels with successful responses", async () => { + await provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock getState to return API configuration + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + litellmApiKey: "litellm-key", + litellmBaseUrl: "http://localhost:4000", + }, + } as any) + + const mockModels = { + "model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model 1" }, + "model-2": { maxTokens: 8192, contextWindow: 16384, description: "Test model 2" }, + } + + const { getModels } = require("../../../api/providers/fetchers/modelCache") + getModels.mockResolvedValue(mockModels) + + await messageHandler({ type: "requestRouterModels" }) + + // Verify getModels was called for each provider with correct options + expect(getModels).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(getModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) + expect(getModels).toHaveBeenCalledWith({ provider: "glama" }) + expect(getModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) + expect(getModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "litellm-key", + baseUrl: "http://localhost:4000", + }) + + // Verify response was sent + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: mockModels, + glama: mockModels, + unbound: mockModels, + litellm: mockModels, + }, + }) + }) + + test("handles requestRouterModels with individual provider failures", async () => { + await provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + litellmApiKey: "litellm-key", + litellmBaseUrl: "http://localhost:4000", + }, + } as any) + + const mockModels = { "model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model" } } + const { getModels } = require("../../../api/providers/fetchers/modelCache") + + // Mock some providers to succeed and others to fail + getModels + .mockResolvedValueOnce(mockModels) // openrouter success + .mockRejectedValueOnce(new Error("Requesty API error")) // requesty fail + .mockResolvedValueOnce(mockModels) // glama success + .mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail + .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail + + await messageHandler({ type: "requestRouterModels" }) + + // Verify main response includes successful providers and empty objects for failed ones + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: {}, + glama: mockModels, + unbound: {}, + litellm: {}, + }, + }) + + // Verify error messages were sent for failed providers + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Requesty API error", + values: { provider: "requesty" }, + }) + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Unbound API error", + values: { provider: "unbound" }, + }) + + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "LiteLLM connection failed", + values: { provider: "litellm" }, + }) + }) + + test("handles requestRouterModels with LiteLLM values from message", async () => { + await provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + // Mock state without LiteLLM config + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + // No litellm config + }, + } as any) + + const mockModels = { "model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model" } } + const { getModels } = require("../../../api/providers/fetchers/modelCache") + getModels.mockResolvedValue(mockModels) + + await messageHandler({ + type: "requestRouterModels", + values: { + litellmApiKey: "message-litellm-key", + litellmBaseUrl: "http://message-url:4000", + }, + }) + + // Verify LiteLLM was called with values from message + expect(getModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "message-litellm-key", + baseUrl: "http://message-url:4000", + }) + }) + + test("skips LiteLLM when neither config nor message values are provided", async () => { + await provider.resolveWebviewView(mockWebviewView) + const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as jest.Mock).mock.calls[0][0] + + jest.spyOn(provider, "getState").mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + // No litellm config + }, + } as any) + + const mockModels = { "model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model" } } + const { getModels } = require("../../../api/providers/fetchers/modelCache") + getModels.mockResolvedValue(mockModels) + + await messageHandler({ type: "requestRouterModels" }) + + // Verify LiteLLM was NOT called + expect(getModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + provider: "litellm", + }), + ) + + // Verify response includes empty object for LiteLLM + expect(mockPostMessage).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: mockModels, + glama: mockModels, + unbound: mockModels, + litellm: {}, + }, + }) + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.test.ts b/src/core/webview/__tests__/webviewMessageHandler.test.ts new file mode 100644 index 0000000000..7f3bc49654 --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.test.ts @@ -0,0 +1,274 @@ +import { webviewMessageHandler } from "../webviewMessageHandler" +import { ClineProvider } from "../ClineProvider" +import { getModels } from "../../../api/providers/fetchers/modelCache" +import { ModelRecord } from "../../../shared/api" + +// Mock dependencies +jest.mock("../../../api/providers/fetchers/modelCache") +const mockGetModels = getModels as jest.MockedFunction + +// Mock ClineProvider +const mockClineProvider = { + getState: jest.fn(), + postMessageToWebview: jest.fn(), +} as unknown as ClineProvider + +describe("webviewMessageHandler - requestRouterModels", () => { + beforeEach(() => { + jest.clearAllMocks() + mockClineProvider.getState = jest.fn().mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + litellmApiKey: "litellm-key", + litellmBaseUrl: "http://localhost:4000", + }, + }) + }) + + it("successfully fetches models from all providers", async () => { + const mockModels: ModelRecord = { + "model-1": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model 1", + }, + "model-2": { + maxTokens: 8192, + contextWindow: 16384, + supportsPromptCache: false, + description: "Test model 2", + }, + } + + mockGetModels.mockResolvedValue(mockModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + }) + + // Verify getModels was called for each provider + expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "litellm-key", + baseUrl: "http://localhost:4000", + }) + + // Verify response was sent + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: mockModels, + glama: mockModels, + unbound: mockModels, + litellm: mockModels, + }, + }) + }) + + it("handles LiteLLM models with values from message when config is missing", async () => { + mockClineProvider.getState = jest.fn().mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + // Missing litellm config + }, + }) + + const mockModels: ModelRecord = { + "model-1": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model 1", + }, + } + + mockGetModels.mockResolvedValue(mockModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + values: { + litellmApiKey: "message-litellm-key", + litellmBaseUrl: "http://message-url:4000", + }, + }) + + // Verify LiteLLM was called with values from message + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "message-litellm-key", + baseUrl: "http://message-url:4000", + }) + }) + + it("skips LiteLLM when both config and message values are missing", async () => { + mockClineProvider.getState = jest.fn().mockResolvedValue({ + apiConfiguration: { + openRouterApiKey: "openrouter-key", + requestyApiKey: "requesty-key", + glamaApiKey: "glama-key", + unboundApiKey: "unbound-key", + // Missing litellm config + }, + }) + + const mockModels: ModelRecord = { + "model-1": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model 1", + }, + } + + mockGetModels.mockResolvedValue(mockModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + // No values provided + }) + + // Verify LiteLLM was NOT called + expect(mockGetModels).not.toHaveBeenCalledWith( + expect.objectContaining({ + provider: "litellm", + }), + ) + + // Verify response includes empty object for LiteLLM + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: mockModels, + glama: mockModels, + unbound: mockModels, + litellm: {}, + }, + }) + }) + + it("handles individual provider failures gracefully", async () => { + const mockModels: ModelRecord = { + "model-1": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model 1", + }, + } + + // Mock some providers to succeed and others to fail + mockGetModels + .mockResolvedValueOnce(mockModels) // openrouter + .mockRejectedValueOnce(new Error("Requesty API error")) // requesty + .mockResolvedValueOnce(mockModels) // glama + .mockRejectedValueOnce(new Error("Unbound API error")) // unbound + .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + }) + + // Verify successful providers are included + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "routerModels", + routerModels: { + openrouter: mockModels, + requesty: {}, + glama: mockModels, + unbound: {}, + litellm: {}, + }, + }) + + // Verify error messages were sent for failed providers + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Requesty API error", + values: { provider: "requesty" }, + }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Unbound API error", + values: { provider: "unbound" }, + }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "LiteLLM connection failed", + values: { provider: "litellm" }, + }) + }) + + it("handles Error objects and string errors correctly", async () => { + // Mock providers to fail with different error types + mockGetModels + .mockRejectedValueOnce(new Error("Structured error message")) // Error object + .mockRejectedValueOnce("String error message") // String error + .mockRejectedValueOnce({ message: "Object with message" }) // Object error + .mockResolvedValueOnce({}) // Success + .mockResolvedValueOnce({}) // Success + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + }) + + // Verify error handling for different error types + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Structured error message", + values: { provider: "openrouter" }, + }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "String error message", + values: { provider: "requesty" }, + }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "[object Object]", + values: { provider: "glama" }, + }) + }) + + it("prefers config values over message values for LiteLLM", async () => { + const mockModels: ModelRecord = {} + mockGetModels.mockResolvedValue(mockModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestRouterModels", + values: { + litellmApiKey: "message-key", + litellmBaseUrl: "http://message-url", + }, + }) + + // Verify config values are used over message values + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "litellm-key", // From config + baseUrl: "http://localhost:4000", // From config + }) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 41232c8680..7c4e906849 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -6,7 +6,7 @@ import * as vscode from "vscode" import { ClineProvider } from "./ClineProvider" import { Language, ProviderSettings, GlobalState, Package } from "../../schemas" import { changeLanguage, t } from "../../i18n" -import { RouterName, toRouterName } from "../../shared/api" +import { RouterName, toRouterName, ModelRecord } from "../../shared/api" import { supportPrompt } from "../../shared/support-prompt" import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage" import { checkExistKey } from "../../shared/checkExistApiConfig" @@ -32,6 +32,7 @@ import { TelemetrySetting } from "../../shared/TelemetrySetting" import { getWorkspacePath } from "../../utils/path" import { Mode, defaultModeSlug } from "../../shared/modes" import { getModels, flushModels } from "../../api/providers/fetchers/modelCache" +import { GetModelsOptions } from "../../shared/api" import { generateSystemPrompt } from "./generateSystemPrompt" import { getCommand } from "../../utils/commands" @@ -282,29 +283,81 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We await provider.resetState() break case "flushRouterModels": - const routerName: RouterName = toRouterName(message.text) - await flushModels(routerName) + const routerNameFlush: RouterName = toRouterName(message.text) + await flushModels(routerNameFlush) break case "requestRouterModels": const { apiConfiguration } = await provider.getState() - const [openRouterModels, requestyModels, glamaModels, unboundModels, litellmModels] = await Promise.all([ - getModels("openrouter", apiConfiguration.openRouterApiKey), - getModels("requesty", apiConfiguration.requestyApiKey), - getModels("glama", apiConfiguration.glamaApiKey), - getModels("unbound", apiConfiguration.unboundApiKey), - getModels("litellm", apiConfiguration.litellmApiKey, apiConfiguration.litellmBaseUrl), - ]) + const routerModels: Partial> = { + openrouter: {}, + requesty: {}, + glama: {}, + unbound: {}, + litellm: {}, + } + + const safeGetModels = async (options: GetModelsOptions): Promise => { + try { + return await getModels(options) + } catch (error) { + console.error( + `Failed to fetch models in webviewMessageHandler requestRouterModels for ${options.provider}:`, + error, + ) + throw error // Re-throw to be caught by Promise.allSettled + } + } + + const modelFetchPromises: Array<{ key: RouterName; options: GetModelsOptions }> = [ + { key: "openrouter", options: { provider: "openrouter" } }, + { key: "requesty", options: { provider: "requesty", apiKey: apiConfiguration.requestyApiKey } }, + { key: "glama", options: { provider: "glama" } }, + { key: "unbound", options: { provider: "unbound", apiKey: apiConfiguration.unboundApiKey } }, + ] + + const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey + const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl + if (litellmApiKey && litellmBaseUrl) { + modelFetchPromises.push({ + key: "litellm", + options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, + }) + } + + const results = await Promise.allSettled( + modelFetchPromises.map(async ({ key, options }) => { + const models = await safeGetModels(options) + return { key, models } // key is RouterName here + }), + ) + + const fetchedRouterModels: Partial> = { ...routerModels } + + results.forEach((result, index) => { + const routerName = modelFetchPromises[index].key // Get RouterName using index + + if (result.status === "fulfilled") { + fetchedRouterModels[routerName] = result.value.models + } else { + // Handle rejection: Post a specific error message for this provider + const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason) + console.error(`Error fetching models for ${routerName}:`, result.reason) + + fetchedRouterModels[routerName] = {} // Ensure it's an empty object in the main routerModels message + + provider.postMessageToWebview({ + type: "singleRouterModelFetchResponse", + success: false, + error: errorMessage, + values: { provider: routerName }, + }) + } + }) provider.postMessageToWebview({ type: "routerModels", - routerModels: { - openrouter: openRouterModels, - requesty: requestyModels, - glama: glamaModels, - unbound: unboundModels, - litellm: litellmModels, - }, + routerModels: fetchedRouterModels as Record, }) break case "requestOpenAiModels": diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 2227847224..8870dcaee0 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -70,6 +70,7 @@ export interface ExtensionMessage { | "commandExecutionStatus" | "vsCodeSetting" | "condenseTaskContextResponse" + | "singleRouterModelFetchResponse" text?: string action?: | "chatButtonClicked" diff --git a/src/shared/api.ts b/src/shared/api.ts index 92cc406854..6f63471788 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -1245,7 +1245,7 @@ export const unboundDefaultModelInfo: ModelInfo = { // LiteLLM // https://docs.litellm.ai/ -export const litellmDefaultModelId = "anthropic/claude-3-7-sonnet-20250219" +export const litellmDefaultModelId = "claude-3-7-sonnet-20250219" export const litellmDefaultModelInfo: ModelInfo = { maxTokens: 8192, contextWindow: 200_000, @@ -1960,3 +1960,15 @@ export const getModelMaxOutputTokens = ({ return model.maxTokens ?? undefined } + +/** + * Options for fetching models from different providers. + * This is a discriminated union type where the provider property determines + * which other properties are required. + */ +export type GetModelsOptions = + | { provider: "openrouter" } + | { provider: "glama" } + | { provider: "requesty"; apiKey?: string } + | { provider: "unbound"; apiKey?: string } + | { provider: "litellm"; apiKey: string; baseUrl: string } diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 378119e70c..20d10cf459 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -390,11 +390,7 @@ const ApiOptions = ({ )} {selectedProvider === "litellm" && ( - + )} {selectedProvider === "human-relay" && ( diff --git a/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx b/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx index 255f66c2e5..19d707a009 100644 --- a/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx +++ b/webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx @@ -146,6 +146,29 @@ jest.mock("../DiffSettingsControl", () => ({ ), })) +// Mock LiteLLM provider for tests +jest.mock("../providers/LiteLLM", () => ({ + LiteLLM: ({ apiConfiguration, setApiConfigurationField }: any) => ( +
+ setApiConfigurationField("litellmBaseUrl", e.target.value)} + placeholder="Base URL" + /> + setApiConfigurationField("litellmApiKey", e.target.value)} + placeholder="API Key" + /> + +
+ ), +})) + jest.mock("@src/components/ui/hooks/useSelectedModel", () => ({ useSelectedModel: jest.fn((apiConfiguration: ProviderSettings) => { if (apiConfiguration.apiModelId?.includes("thinking")) { @@ -388,4 +411,61 @@ describe("ApiOptions", () => { ) }) }) + + describe("LiteLLM provider tests", () => { + it("renders LiteLLM component when provider is selected", () => { + renderApiOptions({ + apiConfiguration: { + apiProvider: "litellm", + litellmBaseUrl: "http://localhost:4000", + litellmApiKey: "test-key", + }, + }) + + expect(screen.getByTestId("litellm-provider")).toBeInTheDocument() + expect(screen.getByTestId("litellm-base-url")).toHaveValue("http://localhost:4000") + expect(screen.getByTestId("litellm-api-key")).toHaveValue("test-key") + }) + + it("calls setApiConfigurationField when LiteLLM inputs change", () => { + const mockSetApiConfigurationField = jest.fn() + renderApiOptions({ + apiConfiguration: { + apiProvider: "litellm", + }, + setApiConfigurationField: mockSetApiConfigurationField, + }) + + const baseUrlInput = screen.getByTestId("litellm-base-url") + const apiKeyInput = screen.getByTestId("litellm-api-key") + + fireEvent.change(baseUrlInput, { target: { value: "http://new-url:8000" } }) + fireEvent.change(apiKeyInput, { target: { value: "new-api-key" } }) + + expect(mockSetApiConfigurationField).toHaveBeenCalledWith("litellmBaseUrl", "http://new-url:8000") + expect(mockSetApiConfigurationField).toHaveBeenCalledWith("litellmApiKey", "new-api-key") + }) + + it("shows refresh models button for LiteLLM", () => { + renderApiOptions({ + apiConfiguration: { + apiProvider: "litellm", + litellmBaseUrl: "http://localhost:4000", + litellmApiKey: "test-key", + }, + }) + + expect(screen.getByTestId("litellm-refresh-models")).toBeInTheDocument() + }) + + it("does not render LiteLLM component when other provider is selected", () => { + renderApiOptions({ + apiConfiguration: { + apiProvider: "anthropic", + }, + }) + + expect(screen.queryByTestId("litellm-provider")).not.toBeInTheDocument() + }) + }) }) diff --git a/webview-ui/src/components/settings/providers/LiteLLM.tsx b/webview-ui/src/components/settings/providers/LiteLLM.tsx index 8f98d97e5d..717b10bb8a 100644 --- a/webview-ui/src/components/settings/providers/LiteLLM.tsx +++ b/webview-ui/src/components/settings/providers/LiteLLM.tsx @@ -1,21 +1,56 @@ -import { useCallback } from "react" +import { useCallback, useState, useEffect, useRef } from "react" import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" -import { ProviderSettings, RouterModels, litellmDefaultModelId } from "@roo/shared/api" +import { ProviderSettings, litellmDefaultModelId, RouterName } from "@roo/shared/api" +import { Button } from "@src/components/ui" +import { vscode } from "@src/utils/vscode" +import { ExtensionMessage } from "@roo/shared/ExtensionMessage" import { useAppTranslation } from "@src/i18n/TranslationContext" import { inputEventTransform } from "../transforms" import { ModelPicker } from "../ModelPicker" +import { useExtensionState } from "@src/context/ExtensionStateContext" type LiteLLMProps = { apiConfiguration: ProviderSettings setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void - routerModels?: RouterModels } -export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, routerModels }: LiteLLMProps) => { +export const LiteLLM = ({ apiConfiguration, setApiConfigurationField }: LiteLLMProps) => { const { t } = useAppTranslation() + const { routerModels } = useExtensionState() + const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") + const [refreshError, setRefreshError] = useState() + const litellmErrorJustReceived = useRef(false) + + useEffect(() => { + const handleMessage = (event: MessageEvent) => { + const message = event.data + if (message.type === "singleRouterModelFetchResponse" && !message.success) { + const providerName = message.values?.provider as RouterName + if (providerName === "litellm") { + litellmErrorJustReceived.current = true + setRefreshStatus("error") + setRefreshError(message.error) + } + } else if (message.type === "routerModels") { + // If we were loading and no specific error for litellm was just received, mark as success. + // The ModelPicker will show available models or "no models found". + if (refreshStatus === "loading") { + if (!litellmErrorJustReceived.current) { + setRefreshStatus("success") + } + // If litellmErrorJustReceived.current is true, status is already (or will be) "error". + } + } + } + + window.addEventListener("message", handleMessage) + return () => { + window.removeEventListener("message", handleMessage) + } + }, [refreshStatus, refreshError, setRefreshStatus, setRefreshError]) const handleInputChange = useCallback( ( @@ -28,12 +63,28 @@ export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, routerMode [setApiConfigurationField], ) + const handleRefreshModels = useCallback(() => { + litellmErrorJustReceived.current = false // Reset flag on new refresh action + setRefreshStatus("loading") + setRefreshError(undefined) + + const key = apiConfiguration.litellmApiKey + const url = apiConfiguration.litellmBaseUrl + + if (!key || !url) { + setRefreshStatus("error") + setRefreshError(t("settings:providers.refreshModels.missingConfig")) + return + } + vscode.postMessage({ type: "requestRouterModels", values: { litellmApiKey: key, litellmBaseUrl: url } }) + }, [apiConfiguration, setRefreshStatus, setRefreshError, t]) + return ( <> @@ -51,6 +102,35 @@ export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, routerMode {t("settings:providers.apiKeyStorageNotice")} + + {refreshStatus === "loading" && ( +
+ {t("settings:providers.refreshModels.loading")} +
+ )} + {refreshStatus === "success" && ( +
{t("settings:providers.refreshModels.success")}
+ )} + {refreshStatus === "error" && ( +
+ {refreshError || t("settings:providers.refreshModels.error")} +
+ )} { const { apiConfiguration, currentApiConfigName, setApiConfiguration, uriScheme, machineId } = useExtensionState() const { t } = useAppTranslation() const [errorMessage, setErrorMessage] = useState(undefined) + // Memoize the setApiConfigurationField function to pass to ApiOptions + const setApiConfigurationFieldForApiOptions = useCallback( + (field: K, value: ProviderSettings[K]) => { + setApiConfiguration({ [field]: value }) + }, + [setApiConfiguration], // setApiConfiguration from context is stable + ) + const handleSubmit = useCallback(() => { const error = apiConfiguration ? validateApiConfiguration(apiConfiguration) : undefined @@ -106,7 +115,7 @@ const WelcomeView = () => { fromWelcomeView apiConfiguration={apiConfiguration || {}} uriScheme={uriScheme} - setApiConfigurationField={(field, value) => setApiConfiguration({ [field]: value })} + setApiConfigurationField={setApiConfigurationFieldForApiOptions} errorMessage={errorMessage} setErrorMessage={setErrorMessage} /> diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index 11c7c9c0d4..2ac1326189 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -10,6 +10,7 @@ import { Mode, CustomModePrompts, defaultModeSlug, defaultPrompts, ModeConfig } import { CustomSupportPrompts } from "@roo/shared/support-prompt" import { experimentDefault, ExperimentId } from "@roo/shared/experiments" import { TelemetrySetting } from "@roo/shared/TelemetrySetting" +import { RouterModels } from "@roo/shared/api" import { vscode } from "@src/utils/vscode" import { convertTextMateToHljs } from "@src/utils/textMateToHljs" @@ -102,6 +103,7 @@ export interface ExtensionStateContextType extends ExtensionState { setHistoryPreviewCollapsed: (value: boolean) => void autoCondenseContextPercent: number setAutoCondenseContextPercent: (value: number) => void + routerModels?: RouterModels } export const ExtensionStateContext = createContext(undefined) @@ -192,12 +194,23 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode const [openedTabs, setOpenedTabs] = useState>([]) const [mcpServers, setMcpServers] = useState([]) const [currentCheckpoint, setCurrentCheckpoint] = useState() + const [extensionRouterModels, setExtensionRouterModels] = useState(undefined) const setListApiConfigMeta = useCallback( (value: ProviderSettingsEntry[]) => setState((prevState) => ({ ...prevState, listApiConfigMeta: value })), [], ) + const setApiConfiguration = useCallback((value: ProviderSettings) => { + setState((prevState) => ({ + ...prevState, + apiConfiguration: { + ...prevState.apiConfiguration, + ...value, + }, + })) + }, []) + const handleMessage = useCallback( (event: MessageEvent) => { const message: ExtensionMessage = event.data @@ -249,6 +262,10 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode setListApiConfigMeta(message.listApiConfig ?? []) break } + case "routerModels": { + setExtensionRouterModels(message.routerModels) + break + } } }, [setListApiConfigMeta], @@ -274,16 +291,10 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode fuzzyMatchThreshold: state.fuzzyMatchThreshold, writeDelayMs: state.writeDelayMs, screenshotQuality: state.screenshotQuality, + routerModels: extensionRouterModels, setExperimentEnabled: (id, enabled) => setState((prevState) => ({ ...prevState, experiments: { ...prevState.experiments, [id]: enabled } })), - setApiConfiguration: (value) => - setState((prevState) => ({ - ...prevState, - apiConfiguration: { - ...prevState.apiConfiguration, - ...value, - }, - })), + setApiConfiguration, setCustomInstructions: (value) => setState((prevState) => ({ ...prevState, customInstructions: value })), setAlwaysAllowReadOnly: (value) => setState((prevState) => ({ ...prevState, alwaysAllowReadOnly: value })), setAlwaysAllowReadOnlyOutsideWorkspace: (value) => @@ -358,7 +369,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode return { ...prevState, pinnedApiConfigs: newPinned } }), setHistoryPreviewCollapsed: (value) => - setState((prevState) => ({ ...prevState, historyPreviewCollapsed: value })), // Implement the setter + setState((prevState) => ({ ...prevState, historyPreviewCollapsed: value })), setAutoCondenseContextPercent: (value) => setState((prevState) => ({ ...prevState, autoCondenseContextPercent: value })), setCondensingApiConfigId: (value) => setState((prevState) => ({ ...prevState, condensingApiConfigId: value })), diff --git a/webview-ui/src/i18n/locales/ca/settings.json b/webview-ui/src/i18n/locales/ca/settings.json index 10951d9ab2..e252e1ba9c 100644 --- a/webview-ui/src/i18n/locales/ca/settings.json +++ b/webview-ui/src/i18n/locales/ca/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Clau API de Requesty", "refreshModels": { "label": "Actualitzar models", - "hint": "Si us plau, torneu a obrir la configuració per veure els models més recents." + "hint": "Si us plau, torneu a obrir la configuració per veure els models més recents.", + "loading": "Actualitzant la llista de models...", + "success": "Llista de models actualitzada correctament!", + "error": "No s'ha pogut actualitzar la llista de models. Si us plau, torneu-ho a provar." }, "getRequestyApiKey": "Obtenir clau API de Requesty", "openRouterTransformsText": "Comprimir prompts i cadenes de missatges a la mida del context (Transformacions d'OpenRouter)", diff --git a/webview-ui/src/i18n/locales/de/settings.json b/webview-ui/src/i18n/locales/de/settings.json index d7d6061928..57307ab6ef 100644 --- a/webview-ui/src/i18n/locales/de/settings.json +++ b/webview-ui/src/i18n/locales/de/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API-Schlüssel", "refreshModels": { "label": "Modelle aktualisieren", - "hint": "Bitte öffne die Einstellungen erneut, um die neuesten Modelle zu sehen." + "hint": "Bitte öffne die Einstellungen erneut, um die neuesten Modelle zu sehen.", + "loading": "Modellliste wird aktualisiert...", + "success": "Modellliste erfolgreich aktualisiert!", + "error": "Fehler beim Aktualisieren der Modellliste. Bitte versuche es erneut." }, "getRequestyApiKey": "Requesty API-Schlüssel erhalten", "openRouterTransformsText": "Prompts und Nachrichtenketten auf Kontextgröße komprimieren (OpenRouter Transformationen)", diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index 913f9b8d4f..e284d1a557 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API Key", "refreshModels": { "label": "Refresh Models", - "hint": "Please reopen the settings to see the latest models." + "hint": "Please reopen the settings to see the latest models.", + "loading": "Refreshing models list...", + "success": "Models list refreshed successfully!", + "error": "Failed to refresh models list. Please try again." }, "getRequestyApiKey": "Get Requesty API Key", "openRouterTransformsText": "Compress prompts and message chains to the context size (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/es/settings.json b/webview-ui/src/i18n/locales/es/settings.json index 13d1714a78..85dea9070f 100644 --- a/webview-ui/src/i18n/locales/es/settings.json +++ b/webview-ui/src/i18n/locales/es/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Clave API de Requesty", "refreshModels": { "label": "Actualizar modelos", - "hint": "Por favor, vuelve a abrir la configuración para ver los modelos más recientes." + "hint": "Por favor, vuelve a abrir la configuración para ver los modelos más recientes.", + "loading": "Actualizando lista de modelos...", + "success": "¡Lista de modelos actualizada correctamente!", + "error": "Error al actualizar la lista de modelos. Por favor, inténtalo de nuevo." }, "getRequestyApiKey": "Obtener clave API de Requesty", "openRouterTransformsText": "Comprimir prompts y cadenas de mensajes al tamaño del contexto (Transformaciones de OpenRouter)", diff --git a/webview-ui/src/i18n/locales/fr/settings.json b/webview-ui/src/i18n/locales/fr/settings.json index d0c3a5abac..76ddb98230 100644 --- a/webview-ui/src/i18n/locales/fr/settings.json +++ b/webview-ui/src/i18n/locales/fr/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Clé API Requesty", "refreshModels": { "label": "Actualiser les modèles", - "hint": "Veuillez rouvrir les paramètres pour voir les modèles les plus récents." + "hint": "Veuillez rouvrir les paramètres pour voir les modèles les plus récents.", + "loading": "Actualisation de la liste des modèles...", + "success": "Liste des modèles actualisée avec succès !", + "error": "Échec de l'actualisation de la liste des modèles. Veuillez réessayer." }, "getRequestyApiKey": "Obtenir la clé API Requesty", "openRouterTransformsText": "Compresser les prompts et chaînes de messages à la taille du contexte (Transformations OpenRouter)", diff --git a/webview-ui/src/i18n/locales/hi/settings.json b/webview-ui/src/i18n/locales/hi/settings.json index f133a74a7d..ee33485921 100644 --- a/webview-ui/src/i18n/locales/hi/settings.json +++ b/webview-ui/src/i18n/locales/hi/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API कुंजी", "refreshModels": { "label": "मॉडल रिफ्रेश करें", - "hint": "नवीनतम मॉडल देखने के लिए कृपया सेटिंग्स को फिर से खोलें।" + "hint": "नवीनतम मॉडल देखने के लिए कृपया सेटिंग्स को फिर से खोलें।", + "loading": "मॉडल सूची अपडेट हो रही है...", + "success": "मॉडल सूची सफलतापूर्वक अपडेट की गई!", + "error": "मॉडल सूची अपडेट करने में विफल। कृपया पुनः प्रयास करें।" }, "getRequestyApiKey": "Requesty API कुंजी प्राप्त करें", "openRouterTransformsText": "संदर्भ आकार के लिए प्रॉम्प्ट और संदेश श्रृंखलाओं को संपीड़ित करें (OpenRouter ट्रांसफॉर्म)", diff --git a/webview-ui/src/i18n/locales/it/settings.json b/webview-ui/src/i18n/locales/it/settings.json index 4d034d08a5..03b10281de 100644 --- a/webview-ui/src/i18n/locales/it/settings.json +++ b/webview-ui/src/i18n/locales/it/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Chiave API Requesty", "refreshModels": { "label": "Aggiorna modelli", - "hint": "Riapri le impostazioni per vedere i modelli più recenti." + "hint": "Riapri le impostazioni per vedere i modelli più recenti.", + "loading": "Aggiornamento dell'elenco dei modelli...", + "success": "Elenco dei modelli aggiornato con successo!", + "error": "Impossibile aggiornare l'elenco dei modelli. Riprova." }, "getRequestyApiKey": "Ottieni chiave API Requesty", "openRouterTransformsText": "Comprimi prompt e catene di messaggi alla dimensione del contesto (Trasformazioni OpenRouter)", diff --git a/webview-ui/src/i18n/locales/ja/settings.json b/webview-ui/src/i18n/locales/ja/settings.json index 8ba05935f1..b4b4191b1a 100644 --- a/webview-ui/src/i18n/locales/ja/settings.json +++ b/webview-ui/src/i18n/locales/ja/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty APIキー", "refreshModels": { "label": "モデルを更新", - "hint": "最新のモデルを表示するには設定を再度開いてください。" + "hint": "最新のモデルを表示するには設定を再度開いてください。", + "loading": "モデルリストを更新中...", + "success": "モデルリストが正常に更新されました!", + "error": "モデルリストの更新に失敗しました。もう一度お試しください。" }, "getRequestyApiKey": "Requesty APIキーを取得", "openRouterTransformsText": "プロンプトとメッセージチェーンをコンテキストサイズに圧縮 (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/ko/settings.json b/webview-ui/src/i18n/locales/ko/settings.json index 41db7df978..d95c7b7ea9 100644 --- a/webview-ui/src/i18n/locales/ko/settings.json +++ b/webview-ui/src/i18n/locales/ko/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API 키", "refreshModels": { "label": "모델 새로고침", - "hint": "최신 모델을 보려면 설정을 다시 열어주세요." + "hint": "최신 모델을 보려면 설정을 다시 열어주세요.", + "loading": "모델 목록 새로고침 중...", + "success": "모델 목록이 성공적으로 새로고침되었습니다!", + "error": "모델 목록 새로고침에 실패했습니다. 다시 시도해 주세요." }, "getRequestyApiKey": "Requesty API 키 받기", "openRouterTransformsText": "프롬프트와 메시지 체인을 컨텍스트 크기로 압축 (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/nl/settings.json b/webview-ui/src/i18n/locales/nl/settings.json index 2bb6535660..dedae007ee 100644 --- a/webview-ui/src/i18n/locales/nl/settings.json +++ b/webview-ui/src/i18n/locales/nl/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API-sleutel", "refreshModels": { "label": "Modellen verversen", - "hint": "Open de instellingen opnieuw om de nieuwste modellen te zien." + "hint": "Open de instellingen opnieuw om de nieuwste modellen te zien.", + "loading": "Modellenlijst wordt vernieuwd...", + "success": "Modellenlijst succesvol vernieuwd!", + "error": "Kan modellenlijst niet vernieuwen. Probeer het opnieuw." }, "getRequestyApiKey": "Requesty API-sleutel ophalen", "openRouterTransformsText": "Comprimeer prompts en berichtreeksen tot de contextgrootte (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/pl/settings.json b/webview-ui/src/i18n/locales/pl/settings.json index 088271d7e4..3695a89545 100644 --- a/webview-ui/src/i18n/locales/pl/settings.json +++ b/webview-ui/src/i18n/locales/pl/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Klucz API Requesty", "refreshModels": { "label": "Odśwież modele", - "hint": "Proszę ponownie otworzyć ustawienia, aby zobaczyć najnowsze modele." + "hint": "Proszę ponownie otworzyć ustawienia, aby zobaczyć najnowsze modele.", + "loading": "Odświeżanie listy modeli...", + "success": "Lista modeli została pomyślnie odświeżona!", + "error": "Nie udało się odświeżyć listy modeli. Spróbuj ponownie." }, "getRequestyApiKey": "Uzyskaj klucz API Requesty", "openRouterTransformsText": "Kompresuj podpowiedzi i łańcuchy wiadomości do rozmiaru kontekstu (Transformacje OpenRouter)", diff --git a/webview-ui/src/i18n/locales/pt-BR/settings.json b/webview-ui/src/i18n/locales/pt-BR/settings.json index 3a804c1f57..0657c5f8ce 100644 --- a/webview-ui/src/i18n/locales/pt-BR/settings.json +++ b/webview-ui/src/i18n/locales/pt-BR/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Chave de API Requesty", "refreshModels": { "label": "Atualizar modelos", - "hint": "Por favor, reabra as configurações para ver os modelos mais recentes." + "hint": "Por favor, reabra as configurações para ver os modelos mais recentes.", + "loading": "Atualizando lista de modelos...", + "success": "Lista de modelos atualizada com sucesso!", + "error": "Falha ao atualizar a lista de modelos. Por favor, tente novamente." }, "getRequestyApiKey": "Obter chave de API Requesty", "openRouterTransformsText": "Comprimir prompts e cadeias de mensagens para o tamanho do contexto (Transformações OpenRouter)", diff --git a/webview-ui/src/i18n/locales/ru/settings.json b/webview-ui/src/i18n/locales/ru/settings.json index 32321ed089..87002bcd1f 100644 --- a/webview-ui/src/i18n/locales/ru/settings.json +++ b/webview-ui/src/i18n/locales/ru/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API-ключ", "refreshModels": { "label": "Обновить модели", - "hint": "Пожалуйста, откройте настройки заново, чтобы увидеть последние модели." + "hint": "Пожалуйста, откройте настройки заново, чтобы увидеть последние модели.", + "loading": "Обновление списка моделей...", + "success": "Список моделей успешно обновлен!", + "error": "Не удалось обновить список моделей. Пожалуйста, попробуйте снова." }, "getRequestyApiKey": "Получить Requesty API-ключ", "openRouterTransformsText": "Сжимать подсказки и цепочки сообщений до размера контекста (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/tr/settings.json b/webview-ui/src/i18n/locales/tr/settings.json index cbfc604288..a89c139899 100644 --- a/webview-ui/src/i18n/locales/tr/settings.json +++ b/webview-ui/src/i18n/locales/tr/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API Anahtarı", "refreshModels": { "label": "Modelleri Yenile", - "hint": "En son modelleri görmek için lütfen ayarları yeniden açın." + "hint": "En son modelleri görmek için lütfen ayarları yeniden açın.", + "loading": "Model listesi yenileniyor...", + "success": "Model listesi başarıyla yenilendi!", + "error": "Model listesi yenilenemedi. Lütfen tekrar deneyin." }, "getRequestyApiKey": "Requesty API Anahtarı Al", "openRouterTransformsText": "İstem ve mesaj zincirlerini bağlam boyutuna sıkıştır (OpenRouter Dönüşümleri)", diff --git a/webview-ui/src/i18n/locales/vi/settings.json b/webview-ui/src/i18n/locales/vi/settings.json index 5c43dddd25..3cef200f7c 100644 --- a/webview-ui/src/i18n/locales/vi/settings.json +++ b/webview-ui/src/i18n/locales/vi/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Khóa API Requesty", "refreshModels": { "label": "Làm mới mô hình", - "hint": "Vui lòng mở lại cài đặt để xem các mô hình mới nhất." + "hint": "Vui lòng mở lại cài đặt để xem các mô hình mới nhất.", + "loading": "Đang làm mới danh sách mô hình...", + "success": "Danh sách mô hình đã được làm mới thành công!", + "error": "Không thể làm mới danh sách mô hình. Vui lòng thử lại." }, "getRequestyApiKey": "Lấy khóa API Requesty", "openRouterTransformsText": "Nén lời nhắc và chuỗi tin nhắn theo kích thước ngữ cảnh (OpenRouter Transforms)", diff --git a/webview-ui/src/i18n/locales/zh-CN/settings.json b/webview-ui/src/i18n/locales/zh-CN/settings.json index a036ba7477..6958c5f5cc 100644 --- a/webview-ui/src/i18n/locales/zh-CN/settings.json +++ b/webview-ui/src/i18n/locales/zh-CN/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API 密钥", "refreshModels": { "label": "刷新模型", - "hint": "请重新打开设置以查看最新模型。" + "hint": "请重新打开设置以查看最新模型。", + "loading": "正在刷新模型列表...", + "success": "模型列表刷新成功!", + "error": "刷新模型列表失败。请重试。" }, "getRequestyApiKey": "获取 Requesty API 密钥", "openRouterTransformsText": "自动压缩提示词和消息链到上下文长度限制内 (OpenRouter转换)", diff --git a/webview-ui/src/i18n/locales/zh-TW/settings.json b/webview-ui/src/i18n/locales/zh-TW/settings.json index f4e30f799c..3b1a5fe619 100644 --- a/webview-ui/src/i18n/locales/zh-TW/settings.json +++ b/webview-ui/src/i18n/locales/zh-TW/settings.json @@ -124,7 +124,10 @@ "requestyApiKey": "Requesty API 金鑰", "refreshModels": { "label": "重新整理模型", - "hint": "請重新開啟設定以查看最新模型。" + "hint": "請重新開啟設定以查看最新模型。", + "loading": "正在重新整理模型列表...", + "success": "模型列表重新整理成功!", + "error": "重新整理模型列表失敗。請再試一次。" }, "getRequestyApiKey": "取得 Requesty API 金鑰", "openRouterTransformsText": "將提示和訊息鏈壓縮到上下文大小 (OpenRouter 轉換)",