diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts new file mode 100644 index 000000000000..7954dc14a26a --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts @@ -0,0 +1,167 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webviewMessageHandler } from "../webviewMessageHandler" +import type { ClineProvider } from "../ClineProvider" + +// Mock vscode (minimal) +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + showWarningMessage: vi.fn(), + showInformationMessage: vi.fn(), + }, + workspace: { + workspaceFolders: undefined, + getConfiguration: vi.fn(() => ({ + get: vi.fn(), + update: vi.fn(), + })), + }, + env: { + clipboard: { writeText: vi.fn() }, + openExternal: vi.fn(), + }, + commands: { + executeCommand: vi.fn(), + }, + Uri: { + parse: vi.fn((s: string) => ({ toString: () => s })), + file: vi.fn((p: string) => ({ fsPath: p })), + }, + ConfigurationTarget: { + Global: 1, + Workspace: 2, + WorkspaceFolder: 3, + }, +})) + +// Mock modelCache getModels/flushModels used by the handler +const getModelsMock = vi.fn() +vi.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: (...args: any[]) => getModelsMock(...args), + flushModels: vi.fn(), +})) + +describe("webviewMessageHandler - requestRouterModels providers filter", () => { + let mockProvider: ClineProvider & { + postMessageToWebview: ReturnType + getState: ReturnType + contextProxy: any + log: ReturnType + } + + beforeEach(() => { + vi.clearAllMocks() + + mockProvider = { + // Only methods used by this code path + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }), + contextProxy: { + getValue: vi.fn(), + setValue: vi.fn(), + globalStorageUri: { fsPath: "/mock/storage" }, + }, + log: vi.fn(), + } as any + + // Default mock: return distinct model maps per provider so we can verify keys + getModelsMock.mockImplementation(async (options: any) => { + switch (options?.provider) { + case "roo": + return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } } + case "openrouter": + return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } } + case "requesty": + return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } } + case "deepinfra": + return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } } + case "glama": + return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } } + case "unbound": + return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } } + case "vercel-ai-gateway": + return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } } + case "io-intelligence": + return { "io/model": { contextWindow: 8192, supportsPromptCache: false } } + case "litellm": + return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } } + default: + return {} + } + }) + }) + + it("fetches only requested provider when values.providers is present (['roo'])", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { providers: ["roo"] }, + } as any, + ) + + // Should post a single routerModels message + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }), + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const payload = call[0] + const routerModels = payload.routerModels as Record> + + // Only "roo" key should be present + const keys = Object.keys(routerModels) + expect(keys).toEqual(["roo"]) + expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet") + + // getModels should have been called exactly once for roo + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["roo"]) + }) + + it("defaults to aggregate fetching when no providers filter is sent", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + } as any, + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record> + + // Aggregate handler initializes many known routers - ensure a few expected keys exist + expect(routerModels).toHaveProperty("openrouter") + expect(routerModels).toHaveProperty("roo") + expect(routerModels).toHaveProperty("requesty") + }) + + it("supports filtering another single provider (['openrouter'])", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { providers: ["openrouter"] }, + } as any, + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record> + const keys = Object.keys(routerModels) + + expect(keys).toEqual(["openrouter"]) + expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5") + + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["openrouter"]) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index e32b818a96e9..4a1495935542 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -757,20 +757,38 @@ export const webviewMessageHandler = async ( case "requestRouterModels": const { apiConfiguration } = await provider.getState() - const routerModels: Record = { - openrouter: {}, - "vercel-ai-gateway": {}, - huggingface: {}, - litellm: {}, - deepinfra: {}, - "io-intelligence": {}, - requesty: {}, - unbound: {}, - glama: {}, - ollama: {}, - lmstudio: {}, - roo: {}, - } + // Optional providers filter coming from the webview + const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined + const requestedProviders = providersFilterRaw + ?.filter((p: unknown) => typeof p === "string") + .map((p: string) => { + try { + return toRouterName(p) + } catch { + return undefined + } + }) + .filter((p): p is RouterName => !!p) + + const hasFilter = !!requestedProviders && requestedProviders.length > 0 + const requestedSet = new Set(requestedProviders || []) + + const routerModels: Record = hasFilter + ? ({} as Record) + : { + openrouter: {}, + "vercel-ai-gateway": {}, + huggingface: {}, + litellm: {}, + deepinfra: {}, + "io-intelligence": {}, + requesty: {}, + unbound: {}, + glama: {}, + ollama: {}, + lmstudio: {}, + roo: {}, + } const safeGetModels = async (options: GetModelsOptions): Promise => { try { @@ -785,7 +803,8 @@ export const webviewMessageHandler = async ( } } - const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [ + // Base candidates (only those handled by this aggregate fetcher) + const candidates: { key: RouterName; options: GetModelsOptions }[] = [ { key: "openrouter", options: { provider: "openrouter" } }, { key: "requesty", @@ -818,29 +837,28 @@ export const webviewMessageHandler = async ( }, ] - // Add IO Intelligence if API key is provided. - const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey - - if (ioIntelligenceApiKey) { - modelFetchPromises.push({ + // IO Intelligence is conditional on api key + if (apiConfiguration.ioIntelligenceApiKey) { + candidates.push({ key: "io-intelligence", - options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey }, + options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey }, }) } - // Don't fetch Ollama and LM Studio models by default anymore. - // They have their own specific handlers: requestOllamaModels and requestLmStudioModels. - + // LiteLLM is conditional on baseUrl+apiKey const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl if (litellmApiKey && litellmBaseUrl) { - modelFetchPromises.push({ + candidates.push({ key: "litellm", options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, }) } + // Apply providers filter (if any) + const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key))) + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) @@ -854,18 +872,7 @@ export const webviewMessageHandler = async ( if (result.status === "fulfilled") { routerModels[routerName] = result.value.models - // Ollama and LM Studio settings pages still need these events. - if (routerName === "ollama" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "ollamaModels", - ollamaModels: result.value.models, - }) - } else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "lmStudioModels", - lmStudioModels: result.value.models, - }) - } + // Ollama and LM Studio settings pages still need these events. They are not fetched here. } else { // Handle rejection: Post a specific error message for this provider. const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)