Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { webviewMessageHandler } from "../webviewMessageHandler"
import type { ClineProvider } from "../ClineProvider"

// Mock vscode (minimal)
// Minimal stand-in for the vscode API surface touched by the handler.
// Each property is its own vi.fn() so call records stay isolated per API.
vi.mock("vscode", () => {
	const windowApi = {
		showErrorMessage: vi.fn(),
		showWarningMessage: vi.fn(),
		showInformationMessage: vi.fn(),
	}
	const workspaceApi = {
		workspaceFolders: undefined,
		getConfiguration: vi.fn(() => ({ get: vi.fn(), update: vi.fn() })),
	}
	return {
		window: windowApi,
		workspace: workspaceApi,
		env: {
			clipboard: { writeText: vi.fn() },
			openExternal: vi.fn(),
		},
		commands: { executeCommand: vi.fn() },
		Uri: {
			parse: vi.fn((s: string) => ({ toString: () => s })),
			file: vi.fn((p: string) => ({ fsPath: p })),
		},
		ConfigurationTarget: { Global: 1, Workspace: 2, WorkspaceFolder: 3 },
	}
})

// Mock modelCache getModels/flushModels used by the handler
// Spy that each test configures; the handler's getModels calls are routed
// through it so call arguments can be inspected afterwards.
const getModelsMock = vi.fn()
vi.mock("../../../api/providers/fetchers/modelCache", () => {
	return {
		// Forward lazily so the hoisted factory picks up the spy at call time.
		getModels: (...args: unknown[]) => getModelsMock(...args),
		flushModels: vi.fn(),
	}
})

describe("webviewMessageHandler - requestRouterModels providers filter", () => {
	let mockProvider: ClineProvider & {
		postMessageToWebview: ReturnType<typeof vi.fn>
		getState: ReturnType<typeof vi.fn>
		contextProxy: any
		log: ReturnType<typeof vi.fn>
	}

	// Canned model maps keyed by provider name — each entry is distinct so the
	// assertions can tell exactly which provider fetches ran.
	const cannedModels: Record<string, Record<string, { contextWindow: number; supportsPromptCache: boolean }>> = {
		roo: { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } },
		openrouter: { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } },
		requesty: { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } },
		deepinfra: { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } },
		glama: { "glama/model": { contextWindow: 8192, supportsPromptCache: false } },
		unbound: { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } },
		"vercel-ai-gateway": { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } },
		"io-intelligence": { "io/model": { contextWindow: 8192, supportsPromptCache: false } },
		litellm: { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } },
	}

	// Dispatch a requestRouterModels message through the handler under test.
	const requestRouterModels = (values?: Record<string, unknown>) =>
		webviewMessageHandler(
			mockProvider as any,
			{ type: "requestRouterModels", ...(values ? { values } : {}) } as any,
		)

	// Find the routerModels message posted back to the webview, if any.
	const findRouterModelsCall = () =>
		(mockProvider.postMessageToWebview as any).mock.calls.find((c: any[]) => c[0]?.type === "routerModels")

	// Providers actually passed to getModels, in call order.
	const calledProviders = () => getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)

	beforeEach(() => {
		vi.clearAllMocks()

		mockProvider = {
			// Only the members exercised by this code path.
			postMessageToWebview: vi.fn(),
			getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
			contextProxy: {
				getValue: vi.fn(),
				setValue: vi.fn(),
				globalStorageUri: { fsPath: "/mock/storage" },
			},
			log: vi.fn(),
		} as any

		// Default mock: resolve each provider to its canned map, unknown → {}.
		getModelsMock.mockImplementation(async (options: any) => cannedModels[options?.provider] ?? {})
	})

	it("fetches only requested provider when values.providers is present (['roo'])", async () => {
		await requestRouterModels({ providers: ["roo"] })

		// Should post a single routerModels message
		expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
			expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
		)

		const call = findRouterModelsCall()
		expect(call).toBeTruthy()
		const routerModels = call[0].routerModels as Record<string, Record<string, any>>

		// Only "roo" key should be present
		expect(Object.keys(routerModels)).toEqual(["roo"])
		expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")

		// getModels should have been called exactly once, for roo
		expect(calledProviders()).toEqual(["roo"])
	})

	it("defaults to aggregate fetching when no providers filter is sent", async () => {
		await requestRouterModels()

		const call = findRouterModelsCall()
		expect(call).toBeTruthy()
		const routerModels = call[0].routerModels as Record<string, Record<string, any>>

		// Aggregate handler initializes many known routers - ensure a few expected keys exist
		expect(routerModels).toHaveProperty("openrouter")
		expect(routerModels).toHaveProperty("roo")
		expect(routerModels).toHaveProperty("requesty")
	})

	it("supports filtering another single provider (['openrouter'])", async () => {
		await requestRouterModels({ providers: ["openrouter"] })

		const call = findRouterModelsCall()
		expect(call).toBeTruthy()
		const routerModels = call[0].routerModels as Record<string, Record<string, any>>

		expect(Object.keys(routerModels)).toEqual(["openrouter"])
		expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")

		expect(calledProviders()).toEqual(["openrouter"])
	})
})
81 changes: 44 additions & 37 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -757,20 +757,38 @@ export const webviewMessageHandler = async (
case "requestRouterModels":
const { apiConfiguration } = await provider.getState()

const routerModels: Record<RouterName, ModelRecord> = {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}
// Optional providers filter coming from the webview
const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for this filter to be an array instead of a single provider?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, if we only want to fetch one at a time I think we should make a function that only fetches one and use that instead of having the conditional logic. The only reason for backwards compatibility is that we broke this into multiple PRs, right? That seems ok but I think we should make a new method for the single lookup and then delete the old one once the webview is updated. Or just do it all in one PR, idk

const requestedProviders = providersFilterRaw
?.filter((p: unknown) => typeof p === "string")
.map((p: string) => {
try {
return toRouterName(p)
} catch {
return undefined
}
})
.filter((p): p is RouterName => !!p)

const hasFilter = !!requestedProviders && requestedProviders.length > 0
const requestedSet = new Set<RouterName>(requestedProviders || [])

const routerModels: Record<RouterName, ModelRecord> = hasFilter
? ({} as Record<RouterName, ModelRecord>)
: {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}

const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
try {
Expand All @@ -785,7 +803,8 @@ export const webviewMessageHandler = async (
}
}

const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
// Base candidates (only those handled by this aggregate fetcher)
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
{ key: "openrouter", options: { provider: "openrouter" } },
{
key: "requesty",
Expand Down Expand Up @@ -818,29 +837,28 @@ export const webviewMessageHandler = async (
},
]

// Add IO Intelligence if API key is provided.
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey

if (ioIntelligenceApiKey) {
modelFetchPromises.push({
// IO Intelligence is conditional on api key
if (apiConfiguration.ioIntelligenceApiKey) {
candidates.push({
key: "io-intelligence",
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
})
}

// Don't fetch Ollama and LM Studio models by default anymore.
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.

// LiteLLM is conditional on baseUrl+apiKey
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl

if (litellmApiKey && litellmBaseUrl) {
modelFetchPromises.push({
candidates.push({
key: "litellm",
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
})
}

// Apply providers filter (if any)
const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key)))

const results = await Promise.allSettled(
modelFetchPromises.map(async ({ key, options }) => {
const models = await safeGetModels(options)
Expand All @@ -854,18 +872,7 @@ export const webviewMessageHandler = async (
if (result.status === "fulfilled") {
routerModels[routerName] = result.value.models

// Ollama and LM Studio settings pages still need these events.
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "ollamaModels",
ollamaModels: result.value.models,
})
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: result.value.models,
})
}
// Ollama and LM Studio settings pages still need these events. They are not fetched here.
} else {
// Handle rejection: Post a specific error message for this provider.
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
Expand Down