Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
})

describe("loading and error states", () => {
it("should return loading state when router models are loading", () => {
it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: true,
Expand All @@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return loading state when open router model providers are loading", () => {
it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
isLoading: false,
Expand All @@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return error state when either hook has an error", () => {
it("should NOT set error when hooks error but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: false,
Expand All @@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isError).toBe(true)
// Error from gated routerModels should not bubble for static provider default
expect(result.current.isError).toBe(false)
})
})

Expand Down
32 changes: 29 additions & 3 deletions webview-ui/src/components/ui/hooks/useRouterModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"

import { vscode } from "@src/utils/vscode"

const getRouterModels = async () =>
// Options accepted by useRouterModels.
type UseRouterModelsOptions = {
	providers?: string[] // subset filter (e.g. ["roo"]); omitted or empty means fetch all providers
	enabled?: boolean // gate fetching entirely; treated as true when omitted
}

let __routerModelsRequestCount = 0

const getRouterModels = async (providers?: string[]) =>
new Promise<RouterModels>((resolve, reject) => {
const requestId = ++__routerModelsRequestCount
const cleanup = () => {
window.removeEventListener("message", handler)
}
Expand All @@ -24,6 +32,10 @@ const getRouterModels = async () =>
cleanup()

if (message.routerModels) {
const keys = Object.keys(message.routerModels || {})
console.debug(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this debug logging?

`[useRouterModels] response #${requestId} providers=${JSON.stringify(providers || "all")} keys=${keys.join(",")}`,
)
resolve(message.routerModels)
} else {
reject(new Error("No router models in response"))
Expand All @@ -32,7 +44,21 @@ const getRouterModels = async () =>
}

window.addEventListener("message", handler)
vscode.postMessage({ type: "requestRouterModels" })
console.debug(
`[useRouterModels] request #${requestId} providers=${JSON.stringify(providers && providers.length ? providers : "all")}`,
)
if (providers && providers.length > 0) {
vscode.postMessage({ type: "requestRouterModels", values: { providers } })
} else {
vscode.postMessage({ type: "requestRouterModels" })
}
})

export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
// Query hook for router model lists. An optional provider subset narrows both
// the fetch and the cache identity; `enabled: false` suppresses the fetch entirely.
export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
	const requested = opts.providers ?? []
	// Copy the caller's array so later mutation on their side cannot leak in;
	// an empty filter is normalized to undefined ("fetch everything").
	const providers = requested.length > 0 ? [...requested] : undefined
	// Cache key is order-insensitive: sort a copy so the fetch still receives
	// the caller's original ordering. An empty join falls back to "all",
	// matching the unfiltered case.
	const cacheKey = [...(providers ?? [])].sort().join(",") || "all"
	return useQuery({
		queryKey: ["routerModels", cacheKey],
		queryFn: () => getRouterModels(providers),
		enabled: opts.enabled !== false,
	})
}
60 changes: 44 additions & 16 deletions webview-ui/src/components/ui/hooks/useSelectedModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,56 @@ import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
import { useLmStudioModels } from "./useLmStudioModels"
import { useOllamaModels } from "./useOllamaModels"

const DYNAMIC_ROUTER_PROVIDERS = new Set<ProviderName>([
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we’re going to forget to update this - isn’t this already in the types somewhere?

"openrouter",
"vercel-ai-gateway",
"litellm",
"deepinfra",
"io-intelligence",
"requesty",
"unbound",
"glama",
"roo",
])

export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
const provider = apiConfiguration?.apiProvider || "anthropic"
const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined

const routerModels = useRouterModels()
// Only fetch router models for dynamic router providers we actually need
const shouldFetchRouterModels = DYNAMIC_ROUTER_PROVIDERS.has(provider as ProviderName)
const routerModels = useRouterModels({
providers: shouldFetchRouterModels ? [provider] : undefined,
enabled: shouldFetchRouterModels, // disable entirely for static providers
})

const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
const lmStudioModels = useLmStudioModels(lmStudioModelId)
const ollamaModels = useOllamaModels(ollamaModelId)

// Compute readiness only for the data actually needed for the selected provider
const needRouterModels = shouldFetchRouterModels
const needOpenRouterProviders = provider === "openrouter"
const needLmStudio = typeof lmStudioModelId !== "undefined"
const needOllama = typeof ollamaModelId !== "undefined"

const isReady =
(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
(!needOllama || typeof ollamaModels.data !== "undefined") &&
(!needRouterModels || typeof routerModels.data !== "undefined") &&
(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")

const { id, info } =
apiConfiguration &&
(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
typeof routerModels.data !== "undefined" &&
typeof openRouterModelProviders.data !== "undefined"
apiConfiguration && isReady
? getSelectedModel({
provider,
apiConfiguration,
routerModels: routerModels.data,
openRouterModelProviders: openRouterModelProviders.data,
lmStudioModels: lmStudioModels.data,
ollamaModels: ollamaModels.data,
routerModels: (routerModels.data || ({} as RouterModels)) as RouterModels,
openRouterModelProviders: (openRouterModelProviders.data || {}) as Record<string, ModelInfo>,
lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
})
: { id: anthropicDefaultModelId, info: undefined }

Expand All @@ -99,13 +125,15 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
id,
info,
isLoading:
routerModels.isLoading ||
openRouterModelProviders.isLoading ||
(apiConfiguration?.lmStudioModelId && lmStudioModels!.isLoading),
(needRouterModels && routerModels.isLoading) ||
(needOpenRouterProviders && openRouterModelProviders.isLoading) ||
(needLmStudio && lmStudioModels!.isLoading) ||
(needOllama && ollamaModels!.isLoading),
isError:
routerModels.isError ||
openRouterModelProviders.isError ||
(apiConfiguration?.lmStudioModelId && lmStudioModels!.isError),
(needRouterModels && routerModels.isError) ||
(needOpenRouterProviders && openRouterModelProviders.isError) ||
(needLmStudio && lmStudioModels!.isError) ||
(needOllama && ollamaModels!.isError),
}
}

Expand Down