Skip to content

Commit 31de103

Browse files
authored
feat: optimize router model fetching with single-provider filtering (#8956)
1 parent 388d405 commit 31de103

File tree

5 files changed

+275
-63
lines changed

5 files changed

+275
-63
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { webviewMessageHandler } from "../webviewMessageHandler"
3+
import type { ClineProvider } from "../ClineProvider"
4+
5+
// Mock vscode (minimal)
6+
vi.mock("vscode", () => ({
7+
window: {
8+
showErrorMessage: vi.fn(),
9+
showWarningMessage: vi.fn(),
10+
showInformationMessage: vi.fn(),
11+
},
12+
workspace: {
13+
workspaceFolders: undefined,
14+
getConfiguration: vi.fn(() => ({
15+
get: vi.fn(),
16+
update: vi.fn(),
17+
})),
18+
},
19+
env: {
20+
clipboard: { writeText: vi.fn() },
21+
openExternal: vi.fn(),
22+
},
23+
commands: {
24+
executeCommand: vi.fn(),
25+
},
26+
Uri: {
27+
parse: vi.fn((s: string) => ({ toString: () => s })),
28+
file: vi.fn((p: string) => ({ fsPath: p })),
29+
},
30+
ConfigurationTarget: {
31+
Global: 1,
32+
Workspace: 2,
33+
WorkspaceFolder: 3,
34+
},
35+
}))
36+
37+
// Mock modelCache getModels/flushModels used by the handler
38+
const getModelsMock = vi.fn()
39+
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
40+
getModels: (...args: any[]) => getModelsMock(...args),
41+
flushModels: vi.fn(),
42+
}))
43+
44+
describe("webviewMessageHandler - requestRouterModels provider filter", () => {
45+
let mockProvider: ClineProvider & {
46+
postMessageToWebview: ReturnType<typeof vi.fn>
47+
getState: ReturnType<typeof vi.fn>
48+
contextProxy: any
49+
log: ReturnType<typeof vi.fn>
50+
}
51+
52+
beforeEach(() => {
53+
vi.clearAllMocks()
54+
55+
mockProvider = {
56+
// Only methods used by this code path
57+
postMessageToWebview: vi.fn(),
58+
getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
59+
contextProxy: {
60+
getValue: vi.fn(),
61+
setValue: vi.fn(),
62+
globalStorageUri: { fsPath: "/mock/storage" },
63+
},
64+
log: vi.fn(),
65+
} as any
66+
67+
// Default mock: return distinct model maps per provider so we can verify keys
68+
getModelsMock.mockImplementation(async (options: any) => {
69+
switch (options?.provider) {
70+
case "roo":
71+
return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
72+
case "openrouter":
73+
return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
74+
case "requesty":
75+
return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
76+
case "deepinfra":
77+
return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
78+
case "glama":
79+
return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
80+
case "unbound":
81+
return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
82+
case "vercel-ai-gateway":
83+
return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
84+
case "io-intelligence":
85+
return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
86+
case "litellm":
87+
return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
88+
default:
89+
return {}
90+
}
91+
})
92+
})
93+
94+
it("fetches only requested provider when values.provider is present ('roo')", async () => {
95+
await webviewMessageHandler(
96+
mockProvider as any,
97+
{
98+
type: "requestRouterModels",
99+
values: { provider: "roo" },
100+
} as any,
101+
)
102+
103+
// Should post a routerModels message to the webview
104+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
105+
expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
106+
)
107+
108+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
109+
(c: any[]) => c[0]?.type === "routerModels",
110+
)
111+
expect(call).toBeTruthy()
112+
const payload = call[0]
113+
const routerModels = payload.routerModels as Record<string, Record<string, any>>
114+
115+
// Only "roo" key should be present
116+
const keys = Object.keys(routerModels)
117+
expect(keys).toEqual(["roo"])
118+
expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")
119+
120+
// getModels should have been called exactly once for roo
121+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
122+
expect(providersCalled).toEqual(["roo"])
123+
})
124+
125+
it("defaults to aggregate fetching when no provider filter is sent", async () => {
126+
await webviewMessageHandler(
127+
mockProvider as any,
128+
{
129+
type: "requestRouterModels",
130+
} as any,
131+
)
132+
133+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
134+
(c: any[]) => c[0]?.type === "routerModels",
135+
)
136+
expect(call).toBeTruthy()
137+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
138+
139+
// Aggregate handler initializes many known routers - ensure a few expected keys exist
140+
expect(routerModels).toHaveProperty("openrouter")
141+
expect(routerModels).toHaveProperty("roo")
142+
expect(routerModels).toHaveProperty("requesty")
143+
})
144+
145+
it("supports filtering another single provider ('openrouter')", async () => {
146+
await webviewMessageHandler(
147+
mockProvider as any,
148+
{
149+
type: "requestRouterModels",
150+
values: { provider: "openrouter" },
151+
} as any,
152+
)
153+
154+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
155+
(c: any[]) => c[0]?.type === "routerModels",
156+
)
157+
expect(call).toBeTruthy()
158+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
159+
const keys = Object.keys(routerModels)
160+
161+
expect(keys).toEqual(["openrouter"])
162+
expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")
163+
164+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
165+
expect(providersCalled).toEqual(["openrouter"])
166+
})
167+
})

src/core/webview/webviewMessageHandler.ts

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -757,20 +757,26 @@ export const webviewMessageHandler = async (
757757
case "requestRouterModels":
758758
const { apiConfiguration } = await provider.getState()
759759

760-
const routerModels: Record<RouterName, ModelRecord> = {
761-
openrouter: {},
762-
"vercel-ai-gateway": {},
763-
huggingface: {},
764-
litellm: {},
765-
deepinfra: {},
766-
"io-intelligence": {},
767-
requesty: {},
768-
unbound: {},
769-
glama: {},
770-
ollama: {},
771-
lmstudio: {},
772-
roo: {},
773-
}
760+
// Optional single provider filter from webview
761+
const requestedProvider = message?.values?.provider
762+
const providerFilter = requestedProvider ? toRouterName(requestedProvider) : undefined
763+
764+
const routerModels: Record<RouterName, ModelRecord> = providerFilter
765+
? ({} as Record<RouterName, ModelRecord>)
766+
: {
767+
openrouter: {},
768+
"vercel-ai-gateway": {},
769+
huggingface: {},
770+
litellm: {},
771+
deepinfra: {},
772+
"io-intelligence": {},
773+
requesty: {},
774+
unbound: {},
775+
glama: {},
776+
ollama: {},
777+
lmstudio: {},
778+
roo: {},
779+
}
774780

775781
const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
776782
try {
@@ -785,7 +791,8 @@ export const webviewMessageHandler = async (
785791
}
786792
}
787793

788-
const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
794+
// Base candidates (only those handled by this aggregate fetcher)
795+
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
789796
{ key: "openrouter", options: { provider: "openrouter" } },
790797
{
791798
key: "requesty",
@@ -818,29 +825,30 @@ export const webviewMessageHandler = async (
818825
},
819826
]
820827

821-
// Add IO Intelligence if API key is provided.
822-
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey
823-
824-
if (ioIntelligenceApiKey) {
825-
modelFetchPromises.push({
828+
// IO Intelligence is conditional on api key
829+
if (apiConfiguration.ioIntelligenceApiKey) {
830+
candidates.push({
826831
key: "io-intelligence",
827-
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
832+
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
828833
})
829834
}
830835

831-
// Don't fetch Ollama and LM Studio models by default anymore.
832-
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.
833-
836+
// LiteLLM is conditional on baseUrl+apiKey
834837
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
835838
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl
836839

837840
if (litellmApiKey && litellmBaseUrl) {
838-
modelFetchPromises.push({
841+
candidates.push({
839842
key: "litellm",
840843
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
841844
})
842845
}
843846

847+
// Apply single provider filter if specified
848+
const modelFetchPromises = providerFilter
849+
? candidates.filter(({ key }) => key === providerFilter)
850+
: candidates
851+
844852
const results = await Promise.allSettled(
845853
modelFetchPromises.map(async ({ key, options }) => {
846854
const models = await safeGetModels(options)
@@ -854,18 +862,7 @@ export const webviewMessageHandler = async (
854862
if (result.status === "fulfilled") {
855863
routerModels[routerName] = result.value.models
856864

857-
// Ollama and LM Studio settings pages still need these events.
858-
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
859-
provider.postMessageToWebview({
860-
type: "ollamaModels",
861-
ollamaModels: result.value.models,
862-
})
863-
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
864-
provider.postMessageToWebview({
865-
type: "lmStudioModels",
866-
lmStudioModels: result.value.models,
867-
})
868-
}
865+
// Ollama and LM Studio models are not fetched here; their settings pages rely on the dedicated requestOllamaModels and requestLmStudioModels handlers.
869866
} else {
870867
// Handle rejection: Post a specific error message for this provider.
871868
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
@@ -882,7 +879,11 @@ export const webviewMessageHandler = async (
882879
}
883880
})
884881

885-
provider.postMessageToWebview({ type: "routerModels", routerModels })
882+
provider.postMessageToWebview({
883+
type: "routerModels",
884+
routerModels,
885+
values: providerFilter ? { provider: requestedProvider } : undefined,
886+
})
886887
break
887888
case "requestOllamaModels": {
888889
// Specific handler for Ollama models only.

webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
291291
})
292292

293293
describe("loading and error states", () => {
294-
it("should return loading state when router models are loading", () => {
294+
it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
295295
mockUseRouterModels.mockReturnValue({
296296
data: undefined,
297297
isLoading: true,
@@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
307307
const wrapper = createWrapper()
308308
const { result } = renderHook(() => useSelectedModel(), { wrapper })
309309

310-
expect(result.current.isLoading).toBe(true)
310+
// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
311+
expect(result.current.isLoading).toBe(false)
311312
})
312313

313-
it("should return loading state when open router model providers are loading", () => {
314+
it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
314315
mockUseRouterModels.mockReturnValue({
315316
data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
316317
isLoading: false,
@@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
326327
const wrapper = createWrapper()
327328
const { result } = renderHook(() => useSelectedModel(), { wrapper })
328329

329-
expect(result.current.isLoading).toBe(true)
330+
// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
331+
expect(result.current.isLoading).toBe(false)
330332
})
331333

332-
it("should return error state when either hook has an error", () => {
334+
it("should NOT set error when hooks error but provider is static (anthropic)", () => {
333335
mockUseRouterModels.mockReturnValue({
334336
data: undefined,
335337
isLoading: false,
@@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
345347
const wrapper = createWrapper()
346348
const { result } = renderHook(() => useSelectedModel(), { wrapper })
347349

348-
expect(result.current.isError).toBe(true)
350+
// Error from gated routerModels should not bubble for static provider default
351+
expect(result.current.isError).toBe(false)
349352
})
350353
})
351354

webview-ui/src/components/ui/hooks/useRouterModels.ts

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
55

66
import { vscode } from "@src/utils/vscode"
77

8-
const getRouterModels = async () =>
8+
type UseRouterModelsOptions = {
9+
provider?: string // single provider filter (e.g. "roo")
10+
enabled?: boolean // gate fetching entirely
11+
}
12+
13+
const getRouterModels = async (provider?: string) =>
914
new Promise<RouterModels>((resolve, reject) => {
1015
const cleanup = () => {
1116
window.removeEventListener("message", handler)
@@ -20,6 +25,14 @@ const getRouterModels = async () =>
2025
const message: ExtensionMessage = event.data
2126

2227
if (message.type === "routerModels") {
28+
const msgProvider = message?.values?.provider as string | undefined
29+
30+
// Verify response matches request
31+
if (provider !== msgProvider) {
32+
// Not our response; ignore and wait for the matching one
33+
return
34+
}
35+
2336
clearTimeout(timeout)
2437
cleanup()
2538

@@ -32,7 +45,18 @@ const getRouterModels = async () =>
3245
}
3346

3447
window.addEventListener("message", handler)
35-
vscode.postMessage({ type: "requestRouterModels" })
48+
if (provider) {
49+
vscode.postMessage({ type: "requestRouterModels", values: { provider } })
50+
} else {
51+
vscode.postMessage({ type: "requestRouterModels" })
52+
}
3653
})
3754

38-
export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
55+
export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
56+
const provider = opts.provider || undefined
57+
return useQuery({
58+
queryKey: ["routerModels", provider || "all"],
59+
queryFn: () => getRouterModels(provider),
60+
enabled: opts.enabled !== false,
61+
})
62+
}

0 commit comments

Comments
 (0)