diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts new file mode 100644 index 000000000000..746a013f0f1d --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts @@ -0,0 +1,167 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { webviewMessageHandler } from "../webviewMessageHandler" +import type { ClineProvider } from "../ClineProvider" + +// Mock vscode (minimal) +vi.mock("vscode", () => ({ + window: { + showErrorMessage: vi.fn(), + showWarningMessage: vi.fn(), + showInformationMessage: vi.fn(), + }, + workspace: { + workspaceFolders: undefined, + getConfiguration: vi.fn(() => ({ + get: vi.fn(), + update: vi.fn(), + })), + }, + env: { + clipboard: { writeText: vi.fn() }, + openExternal: vi.fn(), + }, + commands: { + executeCommand: vi.fn(), + }, + Uri: { + parse: vi.fn((s: string) => ({ toString: () => s })), + file: vi.fn((p: string) => ({ fsPath: p })), + }, + ConfigurationTarget: { + Global: 1, + Workspace: 2, + WorkspaceFolder: 3, + }, +})) + +// Mock modelCache getModels/flushModels used by the handler +const getModelsMock = vi.fn() +vi.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: (...args: any[]) => getModelsMock(...args), + flushModels: vi.fn(), +})) + +describe("webviewMessageHandler - requestRouterModels provider filter", () => { + let mockProvider: ClineProvider & { + postMessageToWebview: ReturnType + getState: ReturnType + contextProxy: any + log: ReturnType + } + + beforeEach(() => { + vi.clearAllMocks() + + mockProvider = { + // Only methods used by this code path + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }), + contextProxy: { + getValue: vi.fn(), + setValue: vi.fn(), + globalStorageUri: { fsPath: "/mock/storage" }, + }, + log: vi.fn(), + } as any + + // Default mock: return distinct model maps per provider so we can verify keys + getModelsMock.mockImplementation(async (options: any) => { + switch (options?.provider) { + case "roo": + return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } } + case "openrouter": + return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } } + case "requesty": + return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } } + case "deepinfra": + return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } } + case "glama": + return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } } + case "unbound": + return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } } + case "vercel-ai-gateway": + return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } } + case "io-intelligence": + return { "io/model": { contextWindow: 8192, supportsPromptCache: false } } + case "litellm": + return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } } + default: + return {} + } + }) + }) + + it("fetches only requested provider when values.provider is present ('roo')", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { provider: "roo" }, + } as any, + ) + + // Should post a single routerModels message + expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }), + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const payload = call[0] + const routerModels = payload.routerModels as Record> + + // Only "roo" key should be present + const keys = Object.keys(routerModels) + expect(keys).toEqual(["roo"]) + expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet") + + // getModels should have been called exactly once for roo + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["roo"]) + }) + + it("defaults to aggregate fetching when no provider filter is sent", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + } as any, + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record> + + // Aggregate handler initializes many known routers - ensure a few expected keys exist + expect(routerModels).toHaveProperty("openrouter") + expect(routerModels).toHaveProperty("roo") + expect(routerModels).toHaveProperty("requesty") + }) + + it("supports filtering another single provider ('openrouter')", async () => { + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { provider: "openrouter" }, + } as any, + ) + + const call = (mockProvider.postMessageToWebview as any).mock.calls.find( + (c: any[]) => c[0]?.type === "routerModels", + ) + expect(call).toBeTruthy() + const routerModels = call[0].routerModels as Record> + const keys = Object.keys(routerModels) + + expect(keys).toEqual(["openrouter"]) + expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5") + + const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) + expect(providersCalled).toEqual(["openrouter"]) + }) +}) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c409f15a65d5..6c52b5ee2899 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -757,20 +757,26 @@ export const webviewMessageHandler = async ( case "requestRouterModels": const { apiConfiguration } = await provider.getState() - const routerModels: Record = { - openrouter: {}, - "vercel-ai-gateway": {}, - huggingface: {}, - litellm: {}, - deepinfra: {}, - "io-intelligence": {}, - requesty: {}, - unbound: {}, - glama: {}, - ollama: {}, - lmstudio: {}, - roo: {}, - } + // Optional single provider filter from webview + const requestedProvider = message?.values?.provider + const providerFilter = requestedProvider ? toRouterName(requestedProvider) : undefined + + const routerModels: Record = providerFilter + ? ({} as Record) + : { + openrouter: {}, + "vercel-ai-gateway": {}, + huggingface: {}, + litellm: {}, + deepinfra: {}, + "io-intelligence": {}, + requesty: {}, + unbound: {}, + glama: {}, + ollama: {}, + lmstudio: {}, + roo: {}, + } const safeGetModels = async (options: GetModelsOptions): Promise => { try { @@ -785,7 +791,8 @@ export const webviewMessageHandler = async ( } } - const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [ + // Base candidates (only those handled by this aggregate fetcher) + const candidates: { key: RouterName; options: GetModelsOptions }[] = [ { key: "openrouter", options: { provider: "openrouter" } }, { key: "requesty", @@ -818,29 +825,30 @@ export const webviewMessageHandler = async ( }, ] - // Add IO Intelligence if API key is provided. - const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey - - if (ioIntelligenceApiKey) { - modelFetchPromises.push({ + // IO Intelligence is conditional on api key + if (apiConfiguration.ioIntelligenceApiKey) { + candidates.push({ key: "io-intelligence", - options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey }, + options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey }, }) } - // Don't fetch Ollama and LM Studio models by default anymore. - // They have their own specific handlers: requestOllamaModels and requestLmStudioModels. - + // LiteLLM is conditional on baseUrl+apiKey const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl if (litellmApiKey && litellmBaseUrl) { - modelFetchPromises.push({ + candidates.push({ key: "litellm", options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, }) } + // Apply single provider filter if specified + const modelFetchPromises = providerFilter + ? candidates.filter(({ key }) => key === providerFilter) + : candidates + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) @@ -854,18 +862,7 @@ export const webviewMessageHandler = async ( if (result.status === "fulfilled") { routerModels[routerName] = result.value.models - // Ollama and LM Studio settings pages still need these events. - if (routerName === "ollama" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "ollamaModels", - ollamaModels: result.value.models, - }) - } else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "lmStudioModels", - lmStudioModels: result.value.models, - }) - } + // Ollama and LM Studio settings pages still need these events. They are not fetched here. } else { // Handle rejection: Post a specific error message for this provider. const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason) @@ -882,7 +879,11 @@ export const webviewMessageHandler = async ( } }) - provider.postMessageToWebview({ type: "routerModels", routerModels }) + provider.postMessageToWebview({ + type: "routerModels", + routerModels, + values: providerFilter ? { provider: requestedProvider } : undefined, + }) break case "requestOllamaModels": { // Specific handler for Ollama models only. diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index e49944a9975c..c8afe3a46272 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -291,7 +291,7 @@ describe("useSelectedModel", () => { }) describe("loading and error states", () => { - it("should return loading state when router models are loading", () => { + it("should NOT set loading when router models are loading but provider is static (anthropic)", () => { mockUseRouterModels.mockReturnValue({ data: undefined, isLoading: true, @@ -307,10 +307,11 @@ describe("useSelectedModel", () => { const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(), { wrapper }) - expect(result.current.isLoading).toBe(true) + // With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false + expect(result.current.isLoading).toBe(false) }) - it("should return loading state when open router model providers are loading", () => { + it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => { mockUseRouterModels.mockReturnValue({ data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} }, isLoading: false, @@ -326,10 +327,11 @@ describe("useSelectedModel", () => { const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(), { wrapper }) - expect(result.current.isLoading).toBe(true) + // With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false + expect(result.current.isLoading).toBe(false) }) - it("should return error state when either hook has an error", () => { + it("should NOT set error when hooks error but provider is static (anthropic)", () => { mockUseRouterModels.mockReturnValue({ data: undefined, isLoading: false, @@ -345,7 +347,8 @@ describe("useSelectedModel", () => { const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(), { wrapper }) - expect(result.current.isError).toBe(true) + // Error from gated routerModels should not bubble for static provider default + expect(result.current.isError).toBe(false) }) }) diff --git a/webview-ui/src/components/ui/hooks/useRouterModels.ts b/webview-ui/src/components/ui/hooks/useRouterModels.ts index 0ca68cc27a6e..2527168bfd7f 100644 --- a/webview-ui/src/components/ui/hooks/useRouterModels.ts +++ b/webview-ui/src/components/ui/hooks/useRouterModels.ts @@ -5,7 +5,12 @@ import { ExtensionMessage } from "@roo/ExtensionMessage" import { vscode } from "@src/utils/vscode" -const getRouterModels = async () => +type UseRouterModelsOptions = { + provider?: string // single provider filter (e.g. "roo") + enabled?: boolean // gate fetching entirely +} + +const getRouterModels = async (provider?: string) => new Promise((resolve, reject) => { const cleanup = () => { window.removeEventListener("message", handler) @@ -20,6 +25,14 @@ const getRouterModels = async () => const message: ExtensionMessage = event.data if (message.type === "routerModels") { + const msgProvider = message?.values?.provider as string | undefined + + // Verify response matches request + if (provider !== msgProvider) { + // Not our response; ignore and wait for the matching one + return + } + clearTimeout(timeout) cleanup() @@ -32,7 +45,18 @@ const getRouterModels = async () => } window.addEventListener("message", handler) - vscode.postMessage({ type: "requestRouterModels" }) + if (provider) { + vscode.postMessage({ type: "requestRouterModels", values: { provider } }) + } else { + vscode.postMessage({ type: "requestRouterModels" }) + } }) -export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels }) +export const useRouterModels = (opts: UseRouterModelsOptions = {}) => { + const provider = opts.provider || undefined + return useQuery({ + queryKey: ["routerModels", provider || "all"], + queryFn: () => getRouterModels(provider), + enabled: opts.enabled !== false, + }) +} diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 55fdd120bd3c..1bfb2ea332d4 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -58,6 +58,7 @@ import { vercelAiGatewayDefaultModelId, BEDROCK_1M_CONTEXT_MODEL_IDS, deepInfraDefaultModelId, + isDynamicProvider, } from "@roo-code/types" import type { ModelRecord, RouterModels } from "@roo/api" @@ -73,24 +74,38 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined - const routerModels = useRouterModels() + // Only fetch router models for dynamic providers + const shouldFetchRouterModels = isDynamicProvider(provider) + const routerModels = useRouterModels({ + provider: shouldFetchRouterModels ? provider : undefined, + enabled: shouldFetchRouterModels, + }) + const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId) const lmStudioModels = useLmStudioModels(lmStudioModelId) const ollamaModels = useOllamaModels(ollamaModelId) + // Compute readiness only for the data actually needed for the selected provider + const needRouterModels = shouldFetchRouterModels + const needOpenRouterProviders = provider === "openrouter" + const needLmStudio = typeof lmStudioModelId !== "undefined" + const needOllama = typeof ollamaModelId !== "undefined" + + const isReady = + (!needLmStudio || typeof lmStudioModels.data !== "undefined") && + (!needOllama || typeof ollamaModels.data !== "undefined") && + (!needRouterModels || typeof routerModels.data !== "undefined") && + (!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined") + const { id, info } = - apiConfiguration && - (typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") && - (typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") && - typeof routerModels.data !== "undefined" && - typeof openRouterModelProviders.data !== "undefined" + apiConfiguration && isReady ? getSelectedModel({ provider, apiConfiguration, - routerModels: routerModels.data, - openRouterModelProviders: openRouterModelProviders.data, - lmStudioModels: lmStudioModels.data, - ollamaModels: ollamaModels.data, + routerModels: (routerModels.data || {}) as RouterModels, + openRouterModelProviders: (openRouterModelProviders.data || {}) as Record, + lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined, + ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined, }) : { id: anthropicDefaultModelId, info: undefined } @@ -99,13 +114,15 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { id, info, isLoading: - routerModels.isLoading || - openRouterModelProviders.isLoading || - (apiConfiguration?.lmStudioModelId && lmStudioModels!.isLoading), + (needRouterModels && routerModels.isLoading) || + (needOpenRouterProviders && openRouterModelProviders.isLoading) || + (needLmStudio && lmStudioModels!.isLoading) || + (needOllama && ollamaModels!.isLoading), isError: - routerModels.isError || - openRouterModelProviders.isError || - (apiConfiguration?.lmStudioModelId && lmStudioModels!.isError), + (needRouterModels && routerModels.isError) || + (needOpenRouterProviders && openRouterModelProviders.isError) || + (needLmStudio && lmStudioModels!.isError) || + (needOllama && ollamaModels!.isError), } }