diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts
new file mode 100644
index 000000000000..7954dc14a26a
--- /dev/null
+++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts
@@ -0,0 +1,167 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import { webviewMessageHandler } from "../webviewMessageHandler"
+import type { ClineProvider } from "../ClineProvider"
+
+// Mock vscode (minimal)
+vi.mock("vscode", () => ({
+	window: {
+		showErrorMessage: vi.fn(),
+		showWarningMessage: vi.fn(),
+		showInformationMessage: vi.fn(),
+	},
+	workspace: {
+		workspaceFolders: undefined,
+		getConfiguration: vi.fn(() => ({
+			get: vi.fn(),
+			update: vi.fn(),
+		})),
+	},
+	env: {
+		clipboard: { writeText: vi.fn() },
+		openExternal: vi.fn(),
+	},
+	commands: {
+		executeCommand: vi.fn(),
+	},
+	Uri: {
+		parse: vi.fn((s: string) => ({ toString: () => s })),
+		file: vi.fn((p: string) => ({ fsPath: p })),
+	},
+	ConfigurationTarget: {
+		Global: 1,
+		Workspace: 2,
+		WorkspaceFolder: 3,
+	},
+}))
+
+// Mock modelCache getModels/flushModels used by the handler
+const getModelsMock = vi.fn()
+vi.mock("../../../api/providers/fetchers/modelCache", () => ({
+	getModels: (...args: any[]) => getModelsMock(...args),
+	flushModels: vi.fn(),
+}))
+
+describe("webviewMessageHandler - requestRouterModels providers filter", () => {
+	let mockProvider: ClineProvider & {
+		postMessageToWebview: ReturnType<typeof vi.fn>
+		getState: ReturnType<typeof vi.fn>
+		contextProxy: any
+		log: ReturnType<typeof vi.fn>
+	}
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+
+		mockProvider = {
+			// Only methods used by this code path
+			postMessageToWebview: vi.fn(),
+			getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
+			contextProxy: {
+				getValue: vi.fn(),
+				setValue: vi.fn(),
+				globalStorageUri: { fsPath: "/mock/storage" },
+			},
+			log: vi.fn(),
+		} as any
+
+		// Default mock: return distinct model maps per provider so we can verify keys
+		getModelsMock.mockImplementation(async (options: any) => {
+			switch (options?.provider) {
+				case "roo":
+					return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
+				case "openrouter":
+					return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
+				case "requesty":
+					return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "deepinfra":
+					return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "glama":
+					return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "unbound":
+					return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "vercel-ai-gateway":
+					return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "io-intelligence":
+					return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
+				case "litellm":
+					return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
+				default:
+					return {}
+			}
+		})
+	})
+
+	it("fetches only requested provider when values.providers is present (['roo'])", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+				values: { providers: ["roo"] },
+			} as any,
+		)
+
+		// Should post a single routerModels message
+		expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
+			expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const payload = call[0]
+		const routerModels = payload.routerModels as Record<string, Record<string, unknown>>
+
+		// Only "roo" key should be present
+		const keys = Object.keys(routerModels)
+		expect(keys).toEqual(["roo"])
+		expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")
+
+		// getModels should have been called exactly once for roo
+		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
+		expect(providersCalled).toEqual(["roo"])
+	})
+
+	it("defaults to aggregate fetching when no providers filter is sent", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+			} as any,
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const routerModels = call[0].routerModels as Record<string, Record<string, unknown>>
+
+		// Aggregate handler initializes many known routers - ensure a few expected keys exist
+		expect(routerModels).toHaveProperty("openrouter")
+		expect(routerModels).toHaveProperty("roo")
+		expect(routerModels).toHaveProperty("requesty")
+	})
+
+	it("supports filtering another single provider (['openrouter'])", async () => {
+		await webviewMessageHandler(
+			mockProvider as any,
+			{
+				type: "requestRouterModels",
+				values: { providers: ["openrouter"] },
+			} as any,
+		)
+
+		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
+			(c: any[]) => c[0]?.type === "routerModels",
+		)
+		expect(call).toBeTruthy()
+		const routerModels = call[0].routerModels as Record<string, Record<string, unknown>>
+		const keys = Object.keys(routerModels)
+
+		expect(keys).toEqual(["openrouter"])
+		expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")
+
+		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
+		expect(providersCalled).toEqual(["openrouter"])
+	})
+})
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index e32b818a96e9..4a1495935542 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -757,20 +757,38 @@ export const webviewMessageHandler = async (
 		case "requestRouterModels":
 			const { apiConfiguration } = await provider.getState()
 
-			const routerModels: Record<RouterName, ModelRecord> = {
-				openrouter: {},
-				"vercel-ai-gateway": {},
-				huggingface: {},
-				litellm: {},
-				deepinfra: {},
-				"io-intelligence": {},
-				requesty: {},
-				unbound: {},
-				glama: {},
-				ollama: {},
-				lmstudio: {},
-				roo: {},
-			}
+			// Optional providers filter coming from the webview
+			const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined
+			const requestedProviders = providersFilterRaw
+				?.filter((p: unknown) => typeof p === "string")
+				.map((p: string) => {
+					try {
+						return toRouterName(p)
+					} catch {
+						return undefined
+					}
+				})
+				.filter((p): p is RouterName => !!p)
+
+			const hasFilter = !!requestedProviders && requestedProviders.length > 0
+			const requestedSet = new Set(requestedProviders || [])
+
+			const routerModels: Record<RouterName, ModelRecord> = hasFilter
+				? ({} as Record<RouterName, ModelRecord>)
+				: {
+						openrouter: {},
+						"vercel-ai-gateway": {},
+						huggingface: {},
+						litellm: {},
+						deepinfra: {},
+						"io-intelligence": {},
+						requesty: {},
+						unbound: {},
+						glama: {},
+						ollama: {},
+						lmstudio: {},
+						roo: {},
+					}
 
 			const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 				try {
@@ -785,7 +803,8 @@ export const webviewMessageHandler = async (
 				}
 			}
 
-			const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
+			// Base candidates (only those handled by this aggregate fetcher)
+			const candidates: { key: RouterName; options: GetModelsOptions }[] = [
 				{ key: "openrouter", options: { provider: "openrouter" } },
 				{
 					key: "requesty",
@@ -818,29 +837,28 @@ export const webviewMessageHandler = async (
 				},
 			]
 
-			// Add IO Intelligence if API key is provided.
-			const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey
-
-			if (ioIntelligenceApiKey) {
-				modelFetchPromises.push({
+			// IO Intelligence is conditional on api key
+			if (apiConfiguration.ioIntelligenceApiKey) {
+				candidates.push({
 					key: "io-intelligence",
-					options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
+					options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
 				})
 			}
 
-			// Don't fetch Ollama and LM Studio models by default anymore.
-			// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.
-
+			// LiteLLM is conditional on baseUrl+apiKey
 			const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
 			const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl
 
 			if (litellmApiKey && litellmBaseUrl) {
-				modelFetchPromises.push({
+				candidates.push({
 					key: "litellm",
 					options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
 				})
 			}
 
+			// Apply providers filter (if any)
+			const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key)))
+
 			const results = await Promise.allSettled(
 				modelFetchPromises.map(async ({ key, options }) => {
 					const models = await safeGetModels(options)
@@ -854,18 +872,7 @@ export const webviewMessageHandler = async (
 				if (result.status === "fulfilled") {
 					routerModels[routerName] = result.value.models
 
-					// Ollama and LM Studio settings pages still need these events.
-					if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
-						provider.postMessageToWebview({
-							type: "ollamaModels",
-							ollamaModels: result.value.models,
-						})
-					} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
-						provider.postMessageToWebview({
-							type: "lmStudioModels",
-							lmStudioModels: result.value.models,
-						})
-					}
+					// Ollama and LM Studio are not fetched here; their settings pages use the dedicated requestOllamaModels and requestLmStudioModels handlers.
 				} else {
 					// Handle rejection: Post a specific error message for this provider.
 					const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts
index e49944a9975c..c8afe3a46272 100644
--- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts
+++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts
@@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
 	})
 
 	describe("loading and error states", () => {
-		it("should return loading state when router models are loading", () => {
+		it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: undefined,
 				isLoading: true,
@@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isLoading).toBe(true)
+			// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
+			expect(result.current.isLoading).toBe(false)
 		})
 
-		it("should return loading state when open router model providers are loading", () => {
+		it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
 				isLoading: false,
@@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isLoading).toBe(true)
+			// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
+			expect(result.current.isLoading).toBe(false)
 		})
 
-		it("should return error state when either hook has an error", () => {
+		it("should NOT set error when hooks error but provider is static (anthropic)", () => {
 			mockUseRouterModels.mockReturnValue({
 				data: undefined,
 				isLoading: false,
@@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
 			const wrapper = createWrapper()
 			const { result } = renderHook(() => useSelectedModel(), { wrapper })
 
-			expect(result.current.isError).toBe(true)
+			// Error from gated routerModels should not bubble for static provider default
+			expect(result.current.isError).toBe(false)
 		})
 	})
diff --git a/webview-ui/src/components/ui/hooks/useRouterModels.ts b/webview-ui/src/components/ui/hooks/useRouterModels.ts
index 0ca68cc27a6e..e2c4c9a1ca86 100644
--- a/webview-ui/src/components/ui/hooks/useRouterModels.ts
+++ b/webview-ui/src/components/ui/hooks/useRouterModels.ts
@@ -5,8 +5,16 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
 
 import { vscode } from "@src/utils/vscode"
 
-const getRouterModels = async () =>
+type UseRouterModelsOptions = {
+	providers?: string[] // subset filter (e.g. ["roo"])
["roo"]) + enabled?: boolean // gate fetching entirely +} + +let __routerModelsRequestCount = 0 + +const getRouterModels = async (providers?: string[]) => new Promise((resolve, reject) => { + const requestId = ++__routerModelsRequestCount const cleanup = () => { window.removeEventListener("message", handler) } @@ -24,6 +32,10 @@ const getRouterModels = async () => cleanup() if (message.routerModels) { + const keys = Object.keys(message.routerModels || {}) + console.debug( + `[useRouterModels] response #${requestId} providers=${JSON.stringify(providers || "all")} keys=${keys.join(",")}`, + ) resolve(message.routerModels) } else { reject(new Error("No router models in response")) @@ -32,7 +44,21 @@ const getRouterModels = async () => } window.addEventListener("message", handler) - vscode.postMessage({ type: "requestRouterModels" }) + console.debug( + `[useRouterModels] request #${requestId} providers=${JSON.stringify(providers && providers.length ? providers : "all")}`, + ) + if (providers && providers.length > 0) { + vscode.postMessage({ type: "requestRouterModels", values: { providers } }) + } else { + vscode.postMessage({ type: "requestRouterModels" }) + } }) -export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels }) +export const useRouterModels = (opts: UseRouterModelsOptions = {}) => { + const providers = opts.providers && opts.providers.length ? [...opts.providers] : undefined + return useQuery({ + queryKey: ["routerModels", providers?.slice().sort().join(",") || "all"], + queryFn: () => getRouterModels(providers), + enabled: opts.enabled !== false, + }) +} diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 55fdd120bd3c..99f7c944b472 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -67,30 +67,56 @@ import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders" import { useLmStudioModels } from "./useLmStudioModels" import { useOllamaModels } from "./useOllamaModels" +const DYNAMIC_ROUTER_PROVIDERS = new Set([ + "openrouter", + "vercel-ai-gateway", + "litellm", + "deepinfra", + "io-intelligence", + "requesty", + "unbound", + "glama", + "roo", +]) + export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { const provider = apiConfiguration?.apiProvider || "anthropic" const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined - const routerModels = useRouterModels() + // Only fetch router models for dynamic router providers we actually need + const shouldFetchRouterModels = DYNAMIC_ROUTER_PROVIDERS.has(provider as ProviderName) + const routerModels = useRouterModels({ + providers: shouldFetchRouterModels ? 
+		enabled: shouldFetchRouterModels, // disable entirely for static providers
+	})
+
 	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
 	const lmStudioModels = useLmStudioModels(lmStudioModelId)
 	const ollamaModels = useOllamaModels(ollamaModelId)
 
+	// Compute readiness only for the data actually needed for the selected provider
+	const needRouterModels = shouldFetchRouterModels
+	const needOpenRouterProviders = provider === "openrouter"
+	const needLmStudio = typeof lmStudioModelId !== "undefined"
+	const needOllama = typeof ollamaModelId !== "undefined"
+
+	const isReady =
+		(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
+		(!needOllama || typeof ollamaModels.data !== "undefined") &&
+		(!needRouterModels || typeof routerModels.data !== "undefined") &&
+		(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")
+
 	const { id, info } =
-		apiConfiguration &&
-		(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
-		(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
-		typeof routerModels.data !== "undefined" &&
-		typeof openRouterModelProviders.data !== "undefined"
+		apiConfiguration && isReady
 			? getSelectedModel({
 					provider,
 					apiConfiguration,
-					routerModels: routerModels.data,
-					openRouterModelProviders: openRouterModelProviders.data,
-					lmStudioModels: lmStudioModels.data,
-					ollamaModels: ollamaModels.data,
+					routerModels: (routerModels.data || ({} as RouterModels)) as RouterModels,
+					openRouterModelProviders: (openRouterModelProviders.data || {}) as Record<string, ModelInfo>,
+					lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
+					ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
 				})
 			: { id: anthropicDefaultModelId, info: undefined }
@@ -99,13 +125,15 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 		id,
 		info,
 		isLoading:
-			routerModels.isLoading ||
-			openRouterModelProviders.isLoading ||
-			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isLoading),
+			(needRouterModels && routerModels.isLoading) ||
+			(needOpenRouterProviders && openRouterModelProviders.isLoading) ||
+			(needLmStudio && lmStudioModels!.isLoading) ||
+			(needOllama && ollamaModels!.isLoading),
 		isError:
-			routerModels.isError ||
-			openRouterModelProviders.isError ||
-			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isError),
+			(needRouterModels && routerModels.isError) ||
+			(needOpenRouterProviders && openRouterModelProviders.isError) ||
+			(needLmStudio && lmStudioModels!.isError) ||
+			(needOllama && ollamaModels!.isError),
 	}
 }
diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx
index 7c68795040b5..6443ccad93d5 100644
--- a/webview-ui/src/context/ExtensionStateContext.tsx
+++ b/webview-ui/src/context/ExtensionStateContext.tsx
@@ -440,12 +440,13 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 	// Watch for authentication state changes and refresh Roo models
 	useEffect(() => {
 		const currentAuth = state.cloudIsAuthenticated ?? false
-		if (!prevCloudIsAuthenticated && currentAuth) {
-			// User just authenticated - refresh Roo models with the new auth token
+		const currentProvider = state.apiConfiguration?.apiProvider
+		if (!prevCloudIsAuthenticated && currentAuth && currentProvider === "roo") {
+			// User just authenticated and Roo is the active provider - refresh Roo models
 			vscode.postMessage({ type: "requestRooModels" })
 		}
 		setPrevCloudIsAuthenticated(currentAuth)
-	}, [state.cloudIsAuthenticated, prevCloudIsAuthenticated])
+	}, [state.cloudIsAuthenticated, prevCloudIsAuthenticated, state.apiConfiguration?.apiProvider])
 
 	const contextValue: ExtensionStateContextType = {
 		...state,
diff --git a/webview-ui/src/context/__tests__/ExtensionStateContext.roo-auth-gate.spec.tsx b/webview-ui/src/context/__tests__/ExtensionStateContext.roo-auth-gate.spec.tsx
new file mode 100644
index 000000000000..d62adf26e935
--- /dev/null
+++ b/webview-ui/src/context/__tests__/ExtensionStateContext.roo-auth-gate.spec.tsx
@@ -0,0 +1,75 @@
+import { render, waitFor } from "@/utils/test-utils"
+import React from "react"
+
+vi.mock("@src/utils/vscode", () => ({
+	vscode: {
+		postMessage: vi.fn(),
+	},
+}))
+
+import { ExtensionStateContextProvider } from "@src/context/ExtensionStateContext"
+import { vscode } from "@src/utils/vscode"
+
+describe("ExtensionStateContext Roo auth gate", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	function postStateMessage(state: any) {
+		window.dispatchEvent(
+			new MessageEvent("message", {
+				data: {
+					type: "state",
+					state,
+				},
+			}),
+		)
+	}
+
+	it("does not post requestRooModels when auth flips and provider !== 'roo'", async () => {
+		render(
+			<ExtensionStateContextProvider>
+				<div />
+			</ExtensionStateContextProvider>,
+		)
+
+		// Flip auth to true with a non-roo provider (anthropic)
+		postStateMessage({
+			cloudIsAuthenticated: true,
+			apiConfiguration: { apiProvider: "anthropic" },
+		})
+
+		// Should NOT fire auth-driven Roo refresh
+		await waitFor(() => {
+			const calls = (vscode.postMessage as any).mock.calls as any[][]
+			const hasRequest = calls.some((c) => c[0]?.type === "requestRooModels")
+			expect(hasRequest).toBe(false)
+		})
+	})
+
+	it("posts requestRooModels when auth flips and provider === 'roo'", async () => {
+		render(
+			<ExtensionStateContextProvider>
+				<div />
+			</ExtensionStateContextProvider>,
+		)
+
+		// Ensure prev false (explicit)
+		postStateMessage({
+			cloudIsAuthenticated: false,
+			apiConfiguration: { apiProvider: "roo" },
+		})
+
+		vi.clearAllMocks()
+
+		// Flip to true with provider roo - should trigger
+		postStateMessage({
+			cloudIsAuthenticated: true,
+			apiConfiguration: { apiProvider: "roo" },
+		})
+
+		await waitFor(() => {
+			expect(vscode.postMessage).toHaveBeenCalledWith({ type: "requestRooModels" })
+		})
+	})
+})