@@ -0,0 +1,167 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { webviewMessageHandler } from "../webviewMessageHandler"
import type { ClineProvider } from "../ClineProvider"

// Mock vscode (minimal)
vi.mock("vscode", () => ({
	window: {
		showErrorMessage: vi.fn(),
		showWarningMessage: vi.fn(),
		showInformationMessage: vi.fn(),
	},
	workspace: {
		workspaceFolders: undefined,
		getConfiguration: vi.fn(() => ({
			get: vi.fn(),
			update: vi.fn(),
		})),
	},
	env: {
		clipboard: { writeText: vi.fn() },
		openExternal: vi.fn(),
	},
	commands: {
		executeCommand: vi.fn(),
	},
	Uri: {
		parse: vi.fn((s: string) => ({ toString: () => s })),
		file: vi.fn((p: string) => ({ fsPath: p })),
	},
	ConfigurationTarget: {
		Global: 1,
		Workspace: 2,
		WorkspaceFolder: 3,
	},
}))

// Mock modelCache getModels/flushModels used by the handler
const getModelsMock = vi.fn()
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
	getModels: (...args: any[]) => getModelsMock(...args),
	flushModels: vi.fn(),
}))

describe("webviewMessageHandler - requestRouterModels providers filter", () => {
	let mockProvider: ClineProvider & {
		postMessageToWebview: ReturnType<typeof vi.fn>
		getState: ReturnType<typeof vi.fn>
		contextProxy: any
		log: ReturnType<typeof vi.fn>
	}

	beforeEach(() => {
		vi.clearAllMocks()

		mockProvider = {
			// Only methods used by this code path
			postMessageToWebview: vi.fn(),
			getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
			contextProxy: {
				getValue: vi.fn(),
				setValue: vi.fn(),
				globalStorageUri: { fsPath: "/mock/storage" },
			},
			log: vi.fn(),
		} as any

		// Default mock: return distinct model maps per provider so we can verify keys
		getModelsMock.mockImplementation(async (options: any) => {
			switch (options?.provider) {
				case "roo":
					return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
				case "openrouter":
					return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
				case "requesty":
					return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "deepinfra":
					return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "glama":
					return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "unbound":
					return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "vercel-ai-gateway":
					return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "io-intelligence":
					return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
				case "litellm":
					return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
				default:
					return {}
			}
		})
	})

	it("fetches only requested provider when values.providers is present (['roo'])", async () => {
		await webviewMessageHandler(
			mockProvider as any,
			{
				type: "requestRouterModels",
				values: { providers: ["roo"] },
			} as any,
		)

		// Should post a single routerModels message
		expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
			expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
		)

		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
			(c: any[]) => c[0]?.type === "routerModels",
		)
		expect(call).toBeTruthy()
		const payload = call[0]
		const routerModels = payload.routerModels as Record<string, Record<string, any>>

		// Only "roo" key should be present
		const keys = Object.keys(routerModels)
		expect(keys).toEqual(["roo"])
		expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")

		// getModels should have been called exactly once, for roo
		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
		expect(providersCalled).toEqual(["roo"])
	})

	it("defaults to aggregate fetching when no providers filter is sent", async () => {
		await webviewMessageHandler(
			mockProvider as any,
			{
				type: "requestRouterModels",
			} as any,
		)

		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
			(c: any[]) => c[0]?.type === "routerModels",
		)
		expect(call).toBeTruthy()
		const routerModels = call[0].routerModels as Record<string, Record<string, any>>

		// The aggregate handler initializes many known routers; ensure a few expected keys exist
		expect(routerModels).toHaveProperty("openrouter")
		expect(routerModels).toHaveProperty("roo")
		expect(routerModels).toHaveProperty("requesty")
	})

	it("supports filtering another single provider (['openrouter'])", async () => {
		await webviewMessageHandler(
			mockProvider as any,
			{
				type: "requestRouterModels",
				values: { providers: ["openrouter"] },
			} as any,
		)

		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
			(c: any[]) => c[0]?.type === "routerModels",
		)
		expect(call).toBeTruthy()
		const routerModels = call[0].routerModels as Record<string, Record<string, any>>
		const keys = Object.keys(routerModels)

		expect(keys).toEqual(["openrouter"])
		expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")

		const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
		expect(providersCalled).toEqual(["openrouter"])
	})
})
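For reference, the round trip these tests exercise looks like this (a minimal sketch; the exact message typings are assumed from context rather than shown in this diff):

	// Webview -> extension: ask for a single router's models.
	vscode.postMessage({ type: "requestRouterModels", values: { providers: ["roo"] } })

	// Extension -> webview: one routerModels payload whose keys mirror the filter, e.g.
	// { type: "routerModels", routerModels: { roo: { "roo/sonnet": { contextWindow: 8192, ... } } } }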
81 changes: 44 additions & 37 deletions src/core/webview/webviewMessageHandler.ts
@@ -757,20 +757,38 @@ export const webviewMessageHandler = async (
case "requestRouterModels":
const { apiConfiguration } = await provider.getState()

const routerModels: Record<RouterName, ModelRecord> = {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}
// Optional providers filter coming from the webview
const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined
const requestedProviders = providersFilterRaw
?.filter((p: unknown) => typeof p === "string")
.map((p: string) => {
try {
return toRouterName(p)
} catch {
return undefined
}
})
.filter((p): p is RouterName => !!p)

const hasFilter = !!requestedProviders && requestedProviders.length > 0
const requestedSet = new Set<RouterName>(requestedProviders || [])

const routerModels: Record<RouterName, ModelRecord> = hasFilter
? ({} as Record<RouterName, ModelRecord>)
: {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}

const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
try {
@@ -785,7 +803,8 @@
}
}

const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
// Base candidates (only those handled by this aggregate fetcher)
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
{ key: "openrouter", options: { provider: "openrouter" } },
{
key: "requesty",
@@ -818,29 +837,28 @@
},
]

// Add IO Intelligence if API key is provided.
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey

if (ioIntelligenceApiKey) {
modelFetchPromises.push({
// IO Intelligence is conditional on api key
if (apiConfiguration.ioIntelligenceApiKey) {
candidates.push({
key: "io-intelligence",
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
})
}

// Don't fetch Ollama and LM Studio models by default anymore.
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.

// LiteLLM is conditional on baseUrl+apiKey
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl

if (litellmApiKey && litellmBaseUrl) {
modelFetchPromises.push({
candidates.push({
key: "litellm",
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
})
}

// Apply providers filter (if any)
const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key)))

const results = await Promise.allSettled(
modelFetchPromises.map(async ({ key, options }) => {
const models = await safeGetModels(options)
@@ -854,18 +872,7 @@
if (result.status === "fulfilled") {
routerModels[routerName] = result.value.models

// Ollama and LM Studio settings pages still need these events.
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "ollamaModels",
ollamaModels: result.value.models,
})
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: result.value.models,
})
}
// Ollama and LM Studio are no longer fetched by this aggregate handler; their settings pages rely on the dedicated requestOllamaModels / requestLmStudioModels messages instead.
} else {
// Handle rejection: Post a specific error message for this provider.
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
@@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
})

describe("loading and error states", () => {
it("should return loading state when router models are loading", () => {
it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: true,
@@ -307,10 +307,11 @@
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return loading state when open router model providers are loading", () => {
it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
isLoading: false,
@@ -326,10 +327,11 @@
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return error state when either hook has an error", () => {
it("should NOT set error when hooks error but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: false,
@@ -345,7 +347,8 @@
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isError).toBe(true)
// Error from gated routerModels should not bubble for static provider default
expect(result.current.isError).toBe(false)
})
})

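These expectations match useSelectedModel gating its router-model queries by provider type. A hypothetical sketch of that gating, for orientation only (the hook's internals are not part of this diff, and isRouterProvider is an assumed name):

	// Hypothetical: skip router fetches entirely for static providers like "anthropic",
	// so their isLoading/isError never reach the caller.
	const routerModels = useRouterModels({ enabled: isRouterProvider(apiConfiguration.apiProvider) })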
32 changes: 29 additions & 3 deletions webview-ui/src/components/ui/hooks/useRouterModels.ts
@@ -5,8 +5,16 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"

import { vscode } from "@src/utils/vscode"

const getRouterModels = async () =>
type UseRouterModelsOptions = {
providers?: string[] // subset filter (e.g. ["roo"])
enabled?: boolean // gate fetching entirely
}

let __routerModelsRequestCount = 0

const getRouterModels = async (providers?: string[]) =>
new Promise<RouterModels>((resolve, reject) => {
const requestId = ++__routerModelsRequestCount
const cleanup = () => {
window.removeEventListener("message", handler)
}
@@ -24,6 +32,10 @@ const getRouterModels = async () =>
cleanup()

if (message.routerModels) {
const keys = Object.keys(message.routerModels || {})
console.debug(
`[useRouterModels] response #${requestId} providers=${JSON.stringify(providers || "all")} keys=${keys.join(",")}`,
)
resolve(message.routerModels)
} else {
reject(new Error("No router models in response"))
@@ -32,7 +44,21 @@
}

window.addEventListener("message", handler)
vscode.postMessage({ type: "requestRouterModels" })
console.debug(
`[useRouterModels] request #${requestId} providers=${JSON.stringify(providers && providers.length ? providers : "all")}`,
)
if (providers && providers.length > 0) {
vscode.postMessage({ type: "requestRouterModels", values: { providers } })
} else {
vscode.postMessage({ type: "requestRouterModels" })
}
})

export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
const providers = opts.providers && opts.providers.length ? [...opts.providers] : undefined
return useQuery({
queryKey: ["routerModels", providers?.slice().sort().join(",") || "all"],
queryFn: () => getRouterModels(providers),
enabled: opts.enabled !== false,
})
}
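A typical call site for the extended hook might look as follows (a usage sketch; the surrounding component and the provider variable are assumptions, not part of this diff):

	// Fetch only the Roo router's models, and skip the request for static providers.
	const { data: routerModels, isLoading } = useRouterModels({
		providers: ["roo"],
		enabled: provider === "roo",
	})

Scoping the query key by the sorted providers list keeps a filtered response from being served to a caller that asked for all providers.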