Skip to content

Commit 66e3f28

Browse files
committed
fix: Gate Roo router model fetches to reduce webview memory usage
- Only fetch Roo models when Roo is the active provider - Gate auth-driven requestRooModels by current provider - Add provider filtering support in backend handler - Disable router model fetches for static providers - Add tests for auth gating and provider filtering Resolves memory growth issue by minimizing unnecessary router model traffic
1 parent ff0c65a commit 66e3f28

File tree

7 files changed

+372
-65
lines changed

7 files changed

+372
-65
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import { describe, it, expect, vi, beforeEach } from "vitest"
2+
import { webviewMessageHandler } from "../webviewMessageHandler"
3+
import type { ClineProvider } from "../ClineProvider"
4+
5+
// Mock vscode (minimal)
6+
vi.mock("vscode", () => ({
7+
window: {
8+
showErrorMessage: vi.fn(),
9+
showWarningMessage: vi.fn(),
10+
showInformationMessage: vi.fn(),
11+
},
12+
workspace: {
13+
workspaceFolders: undefined,
14+
getConfiguration: vi.fn(() => ({
15+
get: vi.fn(),
16+
update: vi.fn(),
17+
})),
18+
},
19+
env: {
20+
clipboard: { writeText: vi.fn() },
21+
openExternal: vi.fn(),
22+
},
23+
commands: {
24+
executeCommand: vi.fn(),
25+
},
26+
Uri: {
27+
parse: vi.fn((s: string) => ({ toString: () => s })),
28+
file: vi.fn((p: string) => ({ fsPath: p })),
29+
},
30+
ConfigurationTarget: {
31+
Global: 1,
32+
Workspace: 2,
33+
WorkspaceFolder: 3,
34+
},
35+
}))
36+
37+
// Mock modelCache getModels/flushModels used by the handler
38+
const getModelsMock = vi.fn()
39+
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
40+
getModels: (...args: any[]) => getModelsMock(...args),
41+
flushModels: vi.fn(),
42+
}))
43+
44+
describe("webviewMessageHandler - requestRouterModels providers filter", () => {
45+
let mockProvider: ClineProvider & {
46+
postMessageToWebview: ReturnType<typeof vi.fn>
47+
getState: ReturnType<typeof vi.fn>
48+
contextProxy: any
49+
log: ReturnType<typeof vi.fn>
50+
}
51+
52+
beforeEach(() => {
53+
vi.clearAllMocks()
54+
55+
mockProvider = {
56+
// Only methods used by this code path
57+
postMessageToWebview: vi.fn(),
58+
getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
59+
contextProxy: {
60+
getValue: vi.fn(),
61+
setValue: vi.fn(),
62+
globalStorageUri: { fsPath: "/mock/storage" },
63+
},
64+
log: vi.fn(),
65+
} as any
66+
67+
// Default mock: return distinct model maps per provider so we can verify keys
68+
getModelsMock.mockImplementation(async (options: any) => {
69+
switch (options?.provider) {
70+
case "roo":
71+
return { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } }
72+
case "openrouter":
73+
return { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } }
74+
case "requesty":
75+
return { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } }
76+
case "deepinfra":
77+
return { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } }
78+
case "glama":
79+
return { "glama/model": { contextWindow: 8192, supportsPromptCache: false } }
80+
case "unbound":
81+
return { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } }
82+
case "vercel-ai-gateway":
83+
return { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } }
84+
case "io-intelligence":
85+
return { "io/model": { contextWindow: 8192, supportsPromptCache: false } }
86+
case "litellm":
87+
return { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } }
88+
default:
89+
return {}
90+
}
91+
})
92+
})
93+
94+
it("fetches only requested provider when values.providers is present (['roo'])", async () => {
95+
await webviewMessageHandler(
96+
mockProvider as any,
97+
{
98+
type: "requestRouterModels",
99+
values: { providers: ["roo"] },
100+
} as any,
101+
)
102+
103+
// Should post a single routerModels message
104+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
105+
expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
106+
)
107+
108+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
109+
(c: any[]) => c[0]?.type === "routerModels",
110+
)
111+
expect(call).toBeTruthy()
112+
const payload = call[0]
113+
const routerModels = payload.routerModels as Record<string, Record<string, any>>
114+
115+
// Only "roo" key should be present
116+
const keys = Object.keys(routerModels)
117+
expect(keys).toEqual(["roo"])
118+
expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")
119+
120+
// getModels should have been called exactly once for roo
121+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
122+
expect(providersCalled).toEqual(["roo"])
123+
})
124+
125+
it("defaults to aggregate fetching when no providers filter is sent", async () => {
126+
await webviewMessageHandler(
127+
mockProvider as any,
128+
{
129+
type: "requestRouterModels",
130+
} as any,
131+
)
132+
133+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
134+
(c: any[]) => c[0]?.type === "routerModels",
135+
)
136+
expect(call).toBeTruthy()
137+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
138+
139+
// Aggregate handler initializes many known routers - ensure a few expected keys exist
140+
expect(routerModels).toHaveProperty("openrouter")
141+
expect(routerModels).toHaveProperty("roo")
142+
expect(routerModels).toHaveProperty("requesty")
143+
})
144+
145+
it("supports filtering another single provider (['openrouter'])", async () => {
146+
await webviewMessageHandler(
147+
mockProvider as any,
148+
{
149+
type: "requestRouterModels",
150+
values: { providers: ["openrouter"] },
151+
} as any,
152+
)
153+
154+
const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
155+
(c: any[]) => c[0]?.type === "routerModels",
156+
)
157+
expect(call).toBeTruthy()
158+
const routerModels = call[0].routerModels as Record<string, Record<string, any>>
159+
const keys = Object.keys(routerModels)
160+
161+
expect(keys).toEqual(["openrouter"])
162+
expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")
163+
164+
const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)
165+
expect(providersCalled).toEqual(["openrouter"])
166+
})
167+
})

src/core/webview/webviewMessageHandler.ts

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -757,20 +757,38 @@ export const webviewMessageHandler = async (
757757
case "requestRouterModels":
758758
const { apiConfiguration } = await provider.getState()
759759

760-
const routerModels: Record<RouterName, ModelRecord> = {
761-
openrouter: {},
762-
"vercel-ai-gateway": {},
763-
huggingface: {},
764-
litellm: {},
765-
deepinfra: {},
766-
"io-intelligence": {},
767-
requesty: {},
768-
unbound: {},
769-
glama: {},
770-
ollama: {},
771-
lmstudio: {},
772-
roo: {},
773-
}
760+
// Optional providers filter coming from the webview
761+
const providersFilterRaw = Array.isArray(message?.values?.providers) ? message.values.providers : undefined
762+
const requestedProviders = providersFilterRaw
763+
?.filter((p: unknown) => typeof p === "string")
764+
.map((p: string) => {
765+
try {
766+
return toRouterName(p)
767+
} catch {
768+
return undefined
769+
}
770+
})
771+
.filter((p): p is RouterName => !!p)
772+
773+
const hasFilter = !!requestedProviders && requestedProviders.length > 0
774+
const requestedSet = new Set<RouterName>(requestedProviders || [])
775+
776+
const routerModels: Record<RouterName, ModelRecord> = hasFilter
777+
? ({} as Record<RouterName, ModelRecord>)
778+
: {
779+
openrouter: {},
780+
"vercel-ai-gateway": {},
781+
huggingface: {},
782+
litellm: {},
783+
deepinfra: {},
784+
"io-intelligence": {},
785+
requesty: {},
786+
unbound: {},
787+
glama: {},
788+
ollama: {},
789+
lmstudio: {},
790+
roo: {},
791+
}
774792

775793
const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
776794
try {
@@ -785,7 +803,8 @@ export const webviewMessageHandler = async (
785803
}
786804
}
787805

788-
const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
806+
// Base candidates (only those handled by this aggregate fetcher)
807+
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
789808
{ key: "openrouter", options: { provider: "openrouter" } },
790809
{
791810
key: "requesty",
@@ -818,29 +837,28 @@ export const webviewMessageHandler = async (
818837
},
819838
]
820839

821-
// Add IO Intelligence if API key is provided.
822-
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey
823-
824-
if (ioIntelligenceApiKey) {
825-
modelFetchPromises.push({
840+
// IO Intelligence is conditional on api key
841+
if (apiConfiguration.ioIntelligenceApiKey) {
842+
candidates.push({
826843
key: "io-intelligence",
827-
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
844+
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
828845
})
829846
}
830847

831-
// Don't fetch Ollama and LM Studio models by default anymore.
832-
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.
833-
848+
// LiteLLM is conditional on baseUrl+apiKey
834849
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
835850
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl
836851

837852
if (litellmApiKey && litellmBaseUrl) {
838-
modelFetchPromises.push({
853+
candidates.push({
839854
key: "litellm",
840855
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
841856
})
842857
}
843858

859+
// Apply providers filter (if any)
860+
const modelFetchPromises = candidates.filter(({ key }) => (!hasFilter ? true : requestedSet.has(key)))
861+
844862
const results = await Promise.allSettled(
845863
modelFetchPromises.map(async ({ key, options }) => {
846864
const models = await safeGetModels(options)
@@ -854,18 +872,7 @@ export const webviewMessageHandler = async (
854872
if (result.status === "fulfilled") {
855873
routerModels[routerName] = result.value.models
856874

857-
// Ollama and LM Studio settings pages still need these events.
858-
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
859-
provider.postMessageToWebview({
860-
type: "ollamaModels",
861-
ollamaModels: result.value.models,
862-
})
863-
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
864-
provider.postMessageToWebview({
865-
type: "lmStudioModels",
866-
lmStudioModels: result.value.models,
867-
})
868-
}
875+
// Ollama and LM Studio settings pages still need these events. They are not fetched here.
869876
} else {
870877
// Handle rejection: Post a specific error message for this provider.
871878
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)

webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
291291
})
292292

293293
describe("loading and error states", () => {
294-
it("should return loading state when router models are loading", () => {
294+
it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
295295
mockUseRouterModels.mockReturnValue({
296296
data: undefined,
297297
isLoading: true,
@@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
307307
const wrapper = createWrapper()
308308
const { result } = renderHook(() => useSelectedModel(), { wrapper })
309309

310-
expect(result.current.isLoading).toBe(true)
310+
// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
311+
expect(result.current.isLoading).toBe(false)
311312
})
312313

313-
it("should return loading state when open router model providers are loading", () => {
314+
it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
314315
mockUseRouterModels.mockReturnValue({
315316
data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
316317
isLoading: false,
@@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
326327
const wrapper = createWrapper()
327328
const { result } = renderHook(() => useSelectedModel(), { wrapper })
328329

329-
expect(result.current.isLoading).toBe(true)
330+
// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
331+
expect(result.current.isLoading).toBe(false)
330332
})
331333

332-
it("should return error state when either hook has an error", () => {
334+
it("should NOT set error when hooks error but provider is static (anthropic)", () => {
333335
mockUseRouterModels.mockReturnValue({
334336
data: undefined,
335337
isLoading: false,
@@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
345347
const wrapper = createWrapper()
346348
const { result } = renderHook(() => useSelectedModel(), { wrapper })
347349

348-
expect(result.current.isError).toBe(true)
350+
// Error from gated routerModels should not bubble for static provider default
351+
expect(result.current.isError).toBe(false)
349352
})
350353
})
351354

webview-ui/src/components/ui/hooks/useRouterModels.ts

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,16 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
55

66
import { vscode } from "@src/utils/vscode"
77

8-
const getRouterModels = async () =>
8+
type UseRouterModelsOptions = {
9+
providers?: string[] // subset filter (e.g. ["roo"])
10+
enabled?: boolean // gate fetching entirely
11+
}
12+
13+
let __routerModelsRequestCount = 0
14+
15+
const getRouterModels = async (providers?: string[]) =>
916
new Promise<RouterModels>((resolve, reject) => {
17+
const requestId = ++__routerModelsRequestCount
1018
const cleanup = () => {
1119
window.removeEventListener("message", handler)
1220
}
@@ -24,6 +32,10 @@ const getRouterModels = async () =>
2432
cleanup()
2533

2634
if (message.routerModels) {
35+
const keys = Object.keys(message.routerModels || {})
36+
console.debug(
37+
`[useRouterModels] response #${requestId} providers=${JSON.stringify(providers || "all")} keys=${keys.join(",")}`,
38+
)
2739
resolve(message.routerModels)
2840
} else {
2941
reject(new Error("No router models in response"))
@@ -32,7 +44,21 @@ const getRouterModels = async () =>
3244
}
3345

3446
window.addEventListener("message", handler)
35-
vscode.postMessage({ type: "requestRouterModels" })
47+
console.debug(
48+
`[useRouterModels] request #${requestId} providers=${JSON.stringify(providers && providers.length ? providers : "all")}`,
49+
)
50+
if (providers && providers.length > 0) {
51+
vscode.postMessage({ type: "requestRouterModels", values: { providers } })
52+
} else {
53+
vscode.postMessage({ type: "requestRouterModels" })
54+
}
3655
})
3756

38-
export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
57+
export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
58+
const providers = opts.providers && opts.providers.length ? [...opts.providers] : undefined
59+
return useQuery({
60+
queryKey: ["routerModels", providers?.slice().sort().join(",") || "all"],
61+
queryFn: () => getRouterModels(providers),
62+
enabled: opts.enabled !== false,
63+
})
64+
}

0 commit comments

Comments
 (0)