diff --git a/src/api/providers/fetchers/__tests__/litellm.spec.ts b/src/api/providers/fetchers/__tests__/litellm.spec.ts index a93c21ee1b0d..aa7257f7e7b9 100644 --- a/src/api/providers/fetchers/__tests__/litellm.spec.ts +++ b/src/api/providers/fetchers/__tests__/litellm.spec.ts @@ -35,7 +35,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -56,7 +56,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -77,7 +77,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -98,7 +98,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -119,7 +119,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -140,7 +140,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -161,7 +161,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -213,7 +213,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) expect(result).toEqual({ @@ -254,7 +254,7 @@ describe("getLiteLLMModels", () => { "Content-Type": "application/json", ...DEFAULT_HEADERS, }, - timeout: 5000, + signal: undefined, }) }) @@ -381,7 +381,7 @@ describe("getLiteLLMModels", () => { expect(mockedAxios.get).toHaveBeenCalledWith( "http://localhost:4000/v1/model/info", expect.objectContaining({ - timeout: 5000, + signal: undefined, }), ) }) diff --git a/src/api/providers/fetchers/__tests__/lmstudio.test.ts b/src/api/providers/fetchers/__tests__/lmstudio.test.ts index a1f06d2e251e..d33c991a1f5f 100644 --- a/src/api/providers/fetchers/__tests__/lmstudio.test.ts +++ b/src/api/providers/fetchers/__tests__/lmstudio.test.ts @@ -113,7 +113,7 @@ describe("LMStudio Fetcher", () => { const result = await getLMStudioModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`, { signal: undefined }) expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1) expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl }) expect(mockListDownloadedModels).toHaveBeenCalledTimes(1) @@ -133,7 +133,7 @@ describe("LMStudio Fetcher", () => { const result = await getLMStudioModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`, { signal: undefined }) expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1) expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl }) expect(mockListDownloadedModels).toHaveBeenCalledTimes(1) @@ -373,7 +373,7 @@ describe("LMStudio Fetcher", () => { await getLMStudioModels("") - expect(mockedAxios.get).toHaveBeenCalledWith(`${defaultBaseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${defaultBaseUrl}/v1/models`, { signal: undefined }) 
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: defaultLmsUrl }) }) @@ -385,7 +385,7 @@ describe("LMStudio Fetcher", () => { await getLMStudioModels(httpsBaseUrl) - expect(mockedAxios.get).toHaveBeenCalledWith(`${httpsBaseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${httpsBaseUrl}/v1/models`, { signal: undefined }) expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: wssLmsUrl }) }) @@ -407,7 +407,7 @@ describe("LMStudio Fetcher", () => { const result = await getLMStudioModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`, { signal: undefined }) expect(MockedLMStudioClientConstructor).not.toHaveBeenCalled() expect(mockListLoaded).not.toHaveBeenCalled() expect(consoleErrorSpy).toHaveBeenCalledWith( @@ -426,7 +426,7 @@ describe("LMStudio Fetcher", () => { const result = await getLMStudioModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`, { signal: undefined }) expect(MockedLMStudioClientConstructor).not.toHaveBeenCalled() expect(mockListLoaded).not.toHaveBeenCalled() expect(consoleInfoSpy).toHaveBeenCalledWith(`Error connecting to LMStudio at ${baseUrl}`) diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 2a72ef1cc5f8..719d553395cc 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -69,7 +69,11 @@ describe("getModels with new GetModelsOptions", () => { baseUrl: "http://localhost:4000", }) - expect(mockGetLiteLLMModels).toHaveBeenCalledWith("test-api-key", "http://localhost:4000") + expect(mockGetLiteLLMModels).toHaveBeenCalledWith( + "test-api-key", + "http://localhost:4000", + expect.any(AbortSignal), + ) expect(result).toEqual(mockModels) }) @@ -103,7 +107,7 @@ describe("getModels with new GetModelsOptions", () => { const result = await getModels({ provider: "requesty", apiKey: DUMMY_REQUESTY_KEY }) - expect(mockGetRequestyModels).toHaveBeenCalledWith(undefined, DUMMY_REQUESTY_KEY) + expect(mockGetRequestyModels).toHaveBeenCalledWith(undefined, DUMMY_REQUESTY_KEY, expect.any(AbortSignal)) expect(result).toEqual(mockModels) }) @@ -137,7 +141,7 @@ describe("getModels with new GetModelsOptions", () => { const result = await getModels({ provider: "unbound", apiKey: DUMMY_UNBOUND_KEY }) - expect(mockGetUnboundModels).toHaveBeenCalledWith(DUMMY_UNBOUND_KEY) + expect(mockGetUnboundModels).toHaveBeenCalledWith(DUMMY_UNBOUND_KEY, expect.any(AbortSignal)) expect(result).toEqual(mockModels) }) diff --git a/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts index 30ad2f41d5bf..b938c0641c46 100644 --- a/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/fetchers/__tests__/vercel-ai-gateway.spec.ts @@ -77,7 +77,9 @@ describe("Vercel AI Gateway Fetchers", () => { const models = await getVercelAiGatewayModels() - expect(mockedAxios.get).toHaveBeenCalledWith("https://ai-gateway.vercel.sh/v1/models") + expect(mockedAxios.get).toHaveBeenCalledWith("https://ai-gateway.vercel.sh/v1/models", { + signal: undefined, + }) expect(Object.keys(models)).toHaveLength(2) // Only 
language models expect(models["anthropic/claude-sonnet-4"]).toBeDefined() expect(models["anthropic/claude-3.5-haiku"]).toBeDefined() diff --git a/src/api/providers/fetchers/deepinfra.ts b/src/api/providers/fetchers/deepinfra.ts index f38daff8224f..5d817cbb7f08 100644 --- a/src/api/providers/fetchers/deepinfra.ts +++ b/src/api/providers/fetchers/deepinfra.ts @@ -35,6 +35,7 @@ const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSch export async function getDeepInfraModels( apiKey?: string, baseUrl: string = "https://api.deepinfra.com/v1/openai", + signal?: AbortSignal, ): Promise> { const headers: Record = { ...DEFAULT_HEADERS } if (apiKey) headers["Authorization"] = `Bearer ${apiKey}` @@ -42,7 +43,7 @@ export async function getDeepInfraModels( const url = `${baseUrl.replace(/\/$/, "")}/models` const models: Record = {} - const response = await axios.get(url, { headers }) + const response = await axios.get(url, { headers, signal }) const parsed = DeepInfraModelsResponseSchema.safeParse(response.data) const data = parsed.success ? parsed.data.data : response.data?.data || [] diff --git a/src/api/providers/fetchers/glama.ts b/src/api/providers/fetchers/glama.ts index ae36c751fb82..f451cd9348d1 100644 --- a/src/api/providers/fetchers/glama.ts +++ b/src/api/providers/fetchers/glama.ts @@ -4,11 +4,11 @@ import type { ModelInfo } from "@roo-code/types" import { parseApiPrice } from "../../../shared/cost" -export async function getGlamaModels(): Promise> { +export async function getGlamaModels(signal?: AbortSignal): Promise> { const models: Record = {} try { - const response = await axios.get("https://glama.ai/api/gateway/v1/models") + const response = await axios.get("https://glama.ai/api/gateway/v1/models", { signal }) const rawModels = response.data for (const rawModel of rawModels) { diff --git a/src/api/providers/fetchers/huggingface.ts b/src/api/providers/fetchers/huggingface.ts index 1a7a995bc6ef..b6b1cd28e655 100644 --- a/src/api/providers/fetchers/huggingface.ts +++ b/src/api/providers/fetchers/huggingface.ts @@ -107,7 +107,7 @@ function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFacePr * @returns A promise that resolves to a record of model IDs to model info * @throws Will throw an error if the request fails */ -export async function getHuggingFaceModels(): Promise { +export async function getHuggingFaceModels(signal?: AbortSignal): Promise { const now = Date.now() if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) { @@ -128,7 +128,7 @@ export async function getHuggingFaceModels(): Promise { Pragma: "no-cache", "Cache-Control": "no-cache", }, - timeout: 10000, + signal, }) const result = huggingFaceApiResponseSchema.safeParse(response.data) @@ -236,7 +236,7 @@ export async function getHuggingFaceModelsWithMetadata(): Promise1 */ -export async function getIOIntelligenceModels(apiKey?: string): Promise { +export async function getIOIntelligenceModels(apiKey?: string, signal?: AbortSignal): Promise { const now = Date.now() if (cache && now - cache.timestamp < IO_INTELLIGENCE_CACHE_DURATION) { @@ -108,7 +108,7 @@ export async function getIOIntelligenceModels(apiKey?: string): Promise { +export async function getLiteLLMModels(apiKey: string, baseUrl: string, signal?: AbortSignal): Promise { try { const headers: Record = { "Content-Type": "application/json", @@ -27,8 +27,7 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise // Normalize the pathname by removing trailing slashes and multiple slashes 
urlObj.pathname = urlObj.pathname.replace(/\/+$/, "").replace(/\/+/g, "/") + "/v1/model/info" const url = urlObj.href - // Added timeout to prevent indefinite hanging - const response = await axios.get(url, { headers, timeout: 5000 }) + const response = await axios.get(url, { headers, signal }) const models: ModelRecord = {} // Process the model info from the response diff --git a/src/api/providers/fetchers/lmstudio.ts b/src/api/providers/fetchers/lmstudio.ts index de3f804c28ae..f7253f938056 100644 --- a/src/api/providers/fetchers/lmstudio.ts +++ b/src/api/providers/fetchers/lmstudio.ts @@ -49,7 +49,10 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn return modelInfo } -export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise> { +export async function getLMStudioModels( + baseUrl = "http://localhost:1234", + signal?: AbortSignal, +): Promise> { // clear the set of models that have full details loaded modelsWithLoadedDetails.clear() // clearing the input can leave an empty string; use the default in that case @@ -66,7 +69,7 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom // test the connection to LM Studio first // errors will be caught further down - await axios.get(`${baseUrl}/v1/models`) + await axios.get(`${baseUrl}/v1/models`, { signal }) const client = new LMStudioClient({ baseUrl: lmsUrl }) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 55b5bc3a3047..a60044d2d582 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -3,8 +3,6 @@ import fs from "fs/promises" import NodeCache from "node-cache" -import type { ProviderName } from "@roo-code/types" - import { safeWriteJson } from "../../../utils/safeWriteJson" import { ContextProxy } from "../../../core/config/ContextProxy" @@ -56,83 +54,81 @@ async function readModels(router: RouterName): Promise export const getModels = async (options: GetModelsOptions): Promise => { const { provider } = options - let models = getModelsFromCache(provider) - - if (models) { - return models + // 1) Try memory cache + const cached = getModelsFromCache(provider) + if (cached) { + return cached } + // 2) Try file cache snapshot try { - switch (provider) { - case "openrouter": - models = await getOpenRouterModels() - break - case "requesty": - // Requesty models endpoint requires an API key for per-user custom policies. - models = await getRequestyModels(options.baseUrl, options.apiKey) - break - case "glama": - models = await getGlamaModels() - break - case "unbound": - // Unbound models endpoint requires an API key to fetch application specific models. - models = await getUnboundModels(options.apiKey) - break - case "litellm": - // Type safety ensures apiKey and baseUrl are always provided for LiteLLM. 
- models = await getLiteLLMModels(options.apiKey, options.baseUrl) - break - case "ollama": - models = await getOllamaModels(options.baseUrl, options.apiKey) - break - case "lmstudio": - models = await getLMStudioModels(options.baseUrl) - break - case "deepinfra": - models = await getDeepInfraModels(options.apiKey, options.baseUrl) - break - case "io-intelligence": - models = await getIOIntelligenceModels(options.apiKey) - break - case "vercel-ai-gateway": - models = await getVercelAiGatewayModels() - break - case "huggingface": - models = await getHuggingFaceModels() - break - case "roo": { - // Roo Code Cloud provider requires baseUrl and optional apiKey - const rooBaseUrl = - options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy" - models = await getRooModels(rooBaseUrl, options.apiKey) - break - } - default: { - // Ensures router is exhaustively checked if RouterName is a strict union. - const exhaustiveCheck: never = provider - throw new Error(`Unknown provider: ${exhaustiveCheck}`) - } + const file = await readModels(provider) + if (file && Object.keys(file).length > 0) { + memoryCache.set(provider, file) + return file + } + } catch { + // ignore file read errors; fall through to network fetch + } + + // 3) Network fetch + const signal = AbortSignal.timeout(30_000) + let models: ModelRecord = {} + + switch (provider) { + case "openrouter": + models = await getOpenRouterModels(undefined, signal) + break + case "requesty": + models = await getRequestyModels(options.baseUrl, options.apiKey, signal) + break + case "glama": + models = await getGlamaModels(signal) + break + case "unbound": + models = await getUnboundModels(options.apiKey, signal) + break + case "litellm": + models = await getLiteLLMModels(options.apiKey as string, options.baseUrl as string, signal) + break + case "ollama": + models = await getOllamaModels(options.baseUrl, options.apiKey, signal) + break + case "lmstudio": + models = await getLMStudioModels(options.baseUrl, signal) + break + case "deepinfra": + models = await getDeepInfraModels(options.apiKey, options.baseUrl, signal) + break + case "io-intelligence": + models = await getIOIntelligenceModels(options.apiKey, signal) + break + case "vercel-ai-gateway": + models = await getVercelAiGatewayModels(undefined, signal) + break + case "huggingface": + models = await getHuggingFaceModels(signal) + break + case "roo": { + const rooBaseUrl = options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy" + models = await getRooModels(rooBaseUrl, options.apiKey, signal) + break } + default: { + throw new Error(`Unknown provider: ${provider}`) + } + } - // Cache the fetched models (even if empty, to signify a successful fetch with no models). - memoryCache.set(provider, models) + memoryCache.set(provider, models) - await writeModels(provider, models).catch((err) => - console.error(`[getModels] Error writing ${provider} models to file cache:`, err), + await writeModels(provider, models).catch((err) => { + console.error( + `[modelCache] Error writing ${provider} to file cache after network fetch:`, + err instanceof Error ? err.message : String(err), ) + }) - try { - models = await readModels(provider) - } catch (error) { - console.error(`[getModels] error reading ${provider} models from file cache`, error) - } - return models || {} - } catch (error) { - // Log the error and re-throw it so the caller can handle it (e.g., show a UI message). 
- console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error) - - throw error // Re-throw the original error to be handled by the caller. - } + return models || {} } /** @@ -144,6 +140,6 @@ export const flushModels = async (router: RouterName) => { memoryCache.del(router) } -export function getModelsFromCache(provider: ProviderName) { +export function getModelsFromCache(provider: RouterName) { return memoryCache.get(provider) } diff --git a/src/api/providers/fetchers/modelEndpointCache.ts b/src/api/providers/fetchers/modelEndpointCache.ts index 256ae8404800..49322a6a7de7 100644 --- a/src/api/providers/fetchers/modelEndpointCache.ts +++ b/src/api/providers/fetchers/modelEndpointCache.ts @@ -46,37 +46,52 @@ export const getModelEndpoints = async ({ } const key = getCacheKey(router, modelId) - let modelProviders = memoryCache.get(key) - if (modelProviders) { - // console.log(`[getModelProviders] NodeCache hit for ${key} -> ${Object.keys(modelProviders).length}`) - return modelProviders + // 1) Try memory cache + const cached = memoryCache.get(key) + if (cached) { + return cached + } + + // 2) Try file cache snapshot + try { + const file = await readModelEndpoints(key) + if (file && Object.keys(file).length > 0) { + memoryCache.set(key, file) + return file + } + } catch { + // ignore file read errors; fall through to network fetch } - modelProviders = await getOpenRouterModelEndpoints(modelId) + // 3) Network fetch + const signal = AbortSignal.timeout(30_000) + let modelProviders: ModelRecord = {} + + modelProviders = await getOpenRouterModelEndpoints(modelId, undefined, signal) if (Object.keys(modelProviders).length > 0) { - // console.log(`[getModelProviders] API fetch for ${key} -> ${Object.keys(modelProviders).length}`) memoryCache.set(key, modelProviders) try { await writeModelEndpoints(key, modelProviders) - // console.log(`[getModelProviders] wrote ${key} endpoints to file cache`) } catch (error) { - console.error(`[getModelProviders] error writing ${key} endpoints to file cache`, error) + console.error( + `[endpointCache] Error writing ${key} to file cache after network fetch:`, + error instanceof Error ? error.message : String(error), + ) } return modelProviders } + // Fallback to file cache if network returned empty (rare) try { - modelProviders = await readModelEndpoints(router) - // console.log(`[getModelProviders] read ${key} endpoints from file cache`) - } catch (error) { - console.error(`[getModelProviders] error reading ${key} endpoints from file cache`, error) + const file = await readModelEndpoints(key) + return file ?? {} + } catch { + return {} } - - return modelProviders ?? 
{} } export const flushModelProviders = async (router: RouterName, modelId: string) => diff --git a/src/api/providers/fetchers/ollama.ts b/src/api/providers/fetchers/ollama.ts index 4bf43b6faf3c..bfb75b6b5e93 100644 --- a/src/api/providers/fetchers/ollama.ts +++ b/src/api/providers/fetchers/ollama.ts @@ -56,6 +56,7 @@ export const parseOllamaModel = (rawModel: OllamaModelInfoResponse): ModelInfo = export async function getOllamaModels( baseUrl = "http://localhost:11434", apiKey?: string, + signal?: AbortSignal, ): Promise> { const models: Record = {} @@ -73,7 +74,7 @@ export async function getOllamaModels( headers["Authorization"] = `Bearer ${apiKey}` } - const response = await axios.get(`${baseUrl}/api/tags`, { headers }) + const response = await axios.get(`${baseUrl}/api/tags`, { headers, signal }) const parsedResponse = OllamaModelsResponseSchema.safeParse(response.data) let modelInfoPromises = [] @@ -86,7 +87,7 @@ export async function getOllamaModels( { model: ollamaModel.model, }, - { headers }, + { headers, signal }, ) .then((ollamaModelInfo) => { models[ollamaModel.name] = parseOllamaModel(ollamaModelInfo.data) diff --git a/src/api/providers/fetchers/openrouter.ts b/src/api/providers/fetchers/openrouter.ts index b546c40a3cfc..8d8659b3f19c 100644 --- a/src/api/providers/fetchers/openrouter.ts +++ b/src/api/providers/fetchers/openrouter.ts @@ -94,12 +94,15 @@ type OpenRouterModelEndpointsResponse = z.infer> { +export async function getOpenRouterModels( + options?: ApiHandlerOptions, + signal?: AbortSignal, +): Promise> { const models: Record = {} const baseURL = options?.openRouterBaseUrl || "https://openrouter.ai/api/v1" try { - const response = await axios.get(`${baseURL}/models`) + const response = await axios.get(`${baseURL}/models`, { signal }) const result = openRouterModelsResponseSchema.safeParse(response.data) const data = result.success ? result.data.data : response.data.data @@ -140,12 +143,15 @@ export async function getOpenRouterModels(options?: ApiHandlerOptions): Promise< export async function getOpenRouterModelEndpoints( modelId: string, options?: ApiHandlerOptions, + signal?: AbortSignal, ): Promise> { const models: Record = {} const baseURL = options?.openRouterBaseUrl || "https://openrouter.ai/api/v1" try { - const response = await axios.get(`${baseURL}/models/${modelId}/endpoints`) + const response = await axios.get(`${baseURL}/models/${modelId}/endpoints`, { + signal, + }) const result = openRouterModelEndpointsResponseSchema.safeParse(response.data) const data = result.success ? 
result.data.data : response.data.data diff --git a/src/api/providers/fetchers/requesty.ts b/src/api/providers/fetchers/requesty.ts index 64c7de668928..1fe0f1585736 100644 --- a/src/api/providers/fetchers/requesty.ts +++ b/src/api/providers/fetchers/requesty.ts @@ -5,7 +5,11 @@ import type { ModelInfo } from "@roo-code/types" import { parseApiPrice } from "../../../shared/cost" import { toRequestyServiceUrl } from "../../../shared/utils/requesty" -export async function getRequestyModels(baseUrl?: string, apiKey?: string): Promise> { +export async function getRequestyModels( + baseUrl?: string, + apiKey?: string, + signal?: AbortSignal, +): Promise> { const models: Record = {} try { @@ -18,7 +22,7 @@ export async function getRequestyModels(baseUrl?: string, apiKey?: string): Prom const resolvedBaseUrl = toRequestyServiceUrl(baseUrl) const modelsUrl = new URL("v1/models", resolvedBaseUrl) - const response = await axios.get(modelsUrl.toString(), { headers }) + const response = await axios.get(modelsUrl.toString(), { headers, signal }) const rawModels = response.data.data for (const rawModel of rawModels) { diff --git a/src/api/providers/fetchers/roo.ts b/src/api/providers/fetchers/roo.ts index 17aec4253b31..0dca037680be 100644 --- a/src/api/providers/fetchers/roo.ts +++ b/src/api/providers/fetchers/roo.ts @@ -13,7 +13,7 @@ import { DEFAULT_HEADERS } from "../constants" * @returns A promise that resolves to a record of model IDs to model info * @throws Will throw an error if the request fails or the response is not as expected. */ -export async function getRooModels(baseUrl: string, apiKey?: string): Promise { +export async function getRooModels(baseUrl: string, apiKey?: string, signal?: AbortSignal): Promise { try { const headers: Record = { "Content-Type": "application/json", @@ -29,87 +29,79 @@ export async function getRooModels(baseUrl: string, apiKey?: string): Promise controller.abort(), 10000) + const response = await fetch(url, { + headers, + signal, + }) - try { - const response = await fetch(url, { - headers, - signal: controller.signal, - }) - - if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`) - } + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } - const data = await response.json() - const models: ModelRecord = {} + const data = await response.json() + const models: ModelRecord = {} - // Validate response against schema - const parsed = RooModelsResponseSchema.safeParse(data) + // Validate response against schema + const parsed = RooModelsResponseSchema.safeParse(data) - if (!parsed.success) { - console.error("Error fetching Roo Code Cloud models: Unexpected response format", data) - console.error("Validation errors:", parsed.error.format()) - throw new Error("Failed to fetch Roo Code Cloud models: Unexpected response format.") - } + if (!parsed.success) { + console.error("Error fetching Roo Code Cloud models: Unexpected response format", data) + console.error("Validation errors:", parsed.error.format()) + throw new Error("Failed to fetch Roo Code Cloud models: Unexpected response format.") + } - // Process the validated model data - for (const model of parsed.data.data) { - const modelId = model.id - - if (!modelId) continue - - // Extract model data from the validated API response - // All required fields are guaranteed by the schema - const contextWindow = model.context_window - const maxTokens = model.max_tokens - const tags = model.tags || [] - const pricing = model.pricing - - // Determine if 
the model supports images based on tags - const supportsImages = tags.includes("vision") - - // Determine if the model supports reasoning effort based on tags - const supportsReasoningEffort = tags.includes("reasoning") - - // Determine if the model requires reasoning effort based on tags - const requiredReasoningEffort = tags.includes("reasoning-required") - - // Parse pricing (API returns strings, convert to numbers) - const inputPrice = parseApiPrice(pricing.input) - const outputPrice = parseApiPrice(pricing.output) - const cacheReadPrice = pricing.input_cache_read ? parseApiPrice(pricing.input_cache_read) : undefined - const cacheWritePrice = pricing.input_cache_write ? parseApiPrice(pricing.input_cache_write) : undefined - - models[modelId] = { - maxTokens, - contextWindow, - supportsImages, - supportsReasoningEffort, - requiredReasoningEffort, - supportsPromptCache: Boolean(cacheReadPrice !== undefined), - inputPrice, - outputPrice, - cacheWritesPrice: cacheWritePrice, - cacheReadsPrice: cacheReadPrice, - description: model.description || model.name, - deprecated: model.deprecated || false, - isFree: tags.includes("free"), - } + // Process the validated model data + for (const model of parsed.data.data) { + const modelId = model.id + + if (!modelId) continue + + // Extract model data from the validated API response + // All required fields are guaranteed by the schema + const contextWindow = model.context_window + const maxTokens = model.max_tokens + const tags = model.tags || [] + const pricing = model.pricing + + // Determine if the model supports images based on tags + const supportsImages = tags.includes("vision") + + // Determine if the model supports reasoning effort based on tags + const supportsReasoningEffort = tags.includes("reasoning") + + // Determine if the model requires reasoning effort based on tags + const requiredReasoningEffort = tags.includes("reasoning-required") + + // Parse pricing (API returns strings, convert to numbers) + const inputPrice = parseApiPrice(pricing.input) + const outputPrice = parseApiPrice(pricing.output) + const cacheReadPrice = pricing.input_cache_read ? parseApiPrice(pricing.input_cache_read) : undefined + const cacheWritePrice = pricing.input_cache_write ? parseApiPrice(pricing.input_cache_write) : undefined + + models[modelId] = { + maxTokens, + contextWindow, + supportsImages, + supportsReasoningEffort, + requiredReasoningEffort, + supportsPromptCache: Boolean(cacheReadPrice !== undefined), + inputPrice, + outputPrice, + cacheWritesPrice: cacheWritePrice, + cacheReadsPrice: cacheReadPrice, + description: model.description || model.name, + deprecated: model.deprecated || false, + isFree: tags.includes("free"), } - - return models - } finally { - clearTimeout(timeoutId) } + + return models } catch (error: any) { console.error("Error fetching Roo Code Cloud models:", error.message ? 
error.message : error) // Handle abort/timeout if (error.name === "AbortError") { - throw new Error("Failed to fetch Roo Code Cloud models: Request timed out after 10 seconds.") + throw new Error("Failed to fetch Roo Code Cloud models: Request timed out.") } // Handle fetch errors diff --git a/src/api/providers/fetchers/unbound.ts b/src/api/providers/fetchers/unbound.ts index 354c0fde58aa..a339a4da193c 100644 --- a/src/api/providers/fetchers/unbound.ts +++ b/src/api/providers/fetchers/unbound.ts @@ -2,7 +2,10 @@ import axios from "axios" import type { ModelInfo } from "@roo-code/types" -export async function getUnboundModels(apiKey?: string | null): Promise> { +export async function getUnboundModels( + apiKey?: string | null, + signal?: AbortSignal, +): Promise> { const models: Record = {} try { @@ -12,7 +15,7 @@ export async function getUnboundModels(apiKey?: string | null): Promise = response.data diff --git a/src/api/providers/fetchers/vercel-ai-gateway.ts b/src/api/providers/fetchers/vercel-ai-gateway.ts index 3b6852c28d52..2e2514fffc16 100644 --- a/src/api/providers/fetchers/vercel-ai-gateway.ts +++ b/src/api/providers/fetchers/vercel-ai-gateway.ts @@ -52,12 +52,15 @@ type VercelAiGatewayModelsResponse = z.infer> { +export async function getVercelAiGatewayModels( + options?: ApiHandlerOptions, + signal?: AbortSignal, +): Promise> { const models: Record = {} const baseURL = "https://ai-gateway.vercel.sh/v1" try { - const response = await axios.get(`${baseURL}/models`) + const response = await axios.get(`${baseURL}/models`, { signal }) const result = vercelAiGatewayModelsResponseSchema.safeParse(response.data) const data = result.success ? result.data.data : response.data.data diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 580b17331194..f237e069a174 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -219,6 +219,9 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH this.models = models this.endpoints = endpoints + console.log( + `[${new Date().toISOString()}] [openrouter] fetchModel() models=${Object.keys(models).length}, endpoints=${Object.keys(endpoints).length}`, + ) return this.getModel() } diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 749e8d090d82..851c7278de69 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -214,16 +214,18 @@ describe("webviewMessageHandler - requestRouterModels", () => { mockGetModels.mockResolvedValue(mockModels) await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", }) // Verify getModels was called for each provider expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" }) - expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ provider: "requesty", apiKey: "requesty-key" }), + ) expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" }) - expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" }) + expect(mockGetModels).toHaveBeenCalledWith(expect.objectContaining({ provider: "deepinfra" })) 
expect(mockGetModels).toHaveBeenCalledWith( expect.objectContaining({ provider: "roo", @@ -281,7 +283,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { mockGetModels.mockResolvedValue(mockModels) await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", values: { litellmApiKey: "message-litellm-key", litellmBaseUrl: "http://message-url:4000", @@ -319,7 +321,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { mockGetModels.mockResolvedValue(mockModels) await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", // No values provided }) @@ -372,7 +374,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", }) // Verify successful providers are included @@ -430,7 +432,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", }) // Verify error handling for different error types @@ -496,7 +498,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { mockGetModels.mockResolvedValue(mockModels) await webviewMessageHandler(mockClineProvider, { - type: "requestRouterModels", + type: "requestRouterModelsAll", values: { litellmApiKey: "message-key", litellmBaseUrl: "http://message-url", diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index e32b818a96e9..71b7954b0949 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -12,8 +12,10 @@ import { type TelemetrySetting, TelemetryEventName, UserSettingsConfig, - DEFAULT_CHECKPOINT_TIMEOUT_SECONDS, } from "@roo-code/types" + +// Default checkpoint timeout (from global-settings.ts) +const DEFAULT_CHECKPOINT_TIMEOUT_SECONDS = 15 import { CloudService } from "@roo-code/cloud" import { TelemetryService } from "@roo-code/telemetry" @@ -24,7 +26,7 @@ import { ClineProvider } from "./ClineProvider" import { handleCheckpointRestoreOperation } from "./checkpointRestoreHandler" import { changeLanguage, t } from "../../i18n" import { Package } from "../../shared/package" -import { type RouterName, type ModelRecord, toRouterName } from "../../shared/api" +import { type RouterName, type ModelRecord, type RouterModels, isRouterName, toRouterName } from "../../shared/api" import { MessageEnhancer } from "./messageEnhancer" import { @@ -52,9 +54,9 @@ import { openMention } from "../mentions" import { getWorkspacePath } from "../../utils/path" import { Mode, defaultModeSlug } from "../../shared/modes" import { getModels, flushModels } from "../../api/providers/fetchers/modelCache" -import { GetModelsOptions } from "../../shared/api" import { generateSystemPrompt } from "./generateSystemPrompt" import { getCommand } from "../../utils/commands" +import { fetchRouterModels } from "../../services/router-models" const ALLOWED_VSCODE_SETTINGS = new Set(["terminal.integrated.inheritEnv"]) @@ -499,6 +501,15 @@ export const webviewMessageHandler = async ( }) provider.isViewLaunched = true + + // Phase 2: Warm caches on activation by fetching all providers once + // This happens in background without blocking the UI + 
webviewMessageHandler(provider, { type: "requestRouterModelsAll" }, marketplaceManager).catch((error) => { + provider.log( + `Background router models fetch on activation failed: ${error instanceof Error ? error.message : String(error)}`, + ) + }) + break case "newTask": // Initializing new instance of Cline will make sure that any @@ -754,136 +765,78 @@ export const webviewMessageHandler = async ( const routerNameFlush: RouterName = toRouterName(message.text) await flushModels(routerNameFlush) break - case "requestRouterModels": + case "requestRouterModels": { + // Phase 2: Scope to active provider during chat/task flows const { apiConfiguration } = await provider.getState() - const routerModels: Record = { - openrouter: {}, - "vercel-ai-gateway": {}, - huggingface: {}, - litellm: {}, - deepinfra: {}, - "io-intelligence": {}, - requesty: {}, - unbound: {}, - glama: {}, - ollama: {}, - lmstudio: {}, - roo: {}, - } - - const safeGetModels = async (options: GetModelsOptions): Promise => { - try { - return await getModels(options) - } catch (error) { - console.error( - `Failed to fetch models in webviewMessageHandler requestRouterModels for ${options.provider}:`, - error, - ) + const { routerModels, errors } = await fetchRouterModels({ + apiConfiguration, + activeProviderOnly: true, + litellmOverrides: message?.values + ? { + apiKey: message.values.litellmApiKey, + baseUrl: message.values.litellmBaseUrl, + } + : undefined, + }) - throw error // Re-throw to be caught by Promise.allSettled. - } - } + // Send error notifications for failed providers + errors.forEach((err) => { + provider.log(`Error fetching models for ${err.provider}: ${err.error}`) + provider.postMessageToWebview({ + type: "singleRouterModelFetchResponse", + success: false, + error: err.error, + values: { provider: err.provider }, + }) + }) - const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [ - { key: "openrouter", options: { provider: "openrouter" } }, - { - key: "requesty", - options: { - provider: "requesty", - apiKey: apiConfiguration.requestyApiKey, - baseUrl: apiConfiguration.requestyBaseUrl, - }, - }, - { key: "glama", options: { provider: "glama" } }, - { key: "unbound", options: { provider: "unbound", apiKey: apiConfiguration.unboundApiKey } }, - { key: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } }, - { - key: "deepinfra", - options: { - provider: "deepinfra", - apiKey: apiConfiguration.deepInfraApiKey, - baseUrl: apiConfiguration.deepInfraBaseUrl, - }, - }, - { - key: "roo", - options: { - provider: "roo", - baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", - apiKey: CloudService.hasInstance() - ? CloudService.instance.authService?.getSessionToken() - : undefined, - }, - }, - ] + provider.postMessageToWebview({ type: "routerModels", routerModels: routerModels as RouterModels }) + break + } + case "requestRouterModelsAll": { + // Settings and activation: fetch all providers (legacy behavior) + const { apiConfiguration } = await provider.getState() - // Add IO Intelligence if API key is provided. - const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey + const { routerModels, errors } = await fetchRouterModels({ + apiConfiguration, + activeProviderOnly: false, + litellmOverrides: message?.values + ? 
{ + apiKey: message.values.litellmApiKey, + baseUrl: message.values.litellmBaseUrl, + } + : undefined, + }) - if (ioIntelligenceApiKey) { - modelFetchPromises.push({ - key: "io-intelligence", - options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey }, + // Send error notifications for failed providers + errors.forEach((err) => { + provider.log(`Error fetching models for ${err.provider}: ${err.error}`) + provider.postMessageToWebview({ + type: "singleRouterModelFetchResponse", + success: false, + error: err.error, + values: { provider: err.provider }, }) - } - - // Don't fetch Ollama and LM Studio models by default anymore. - // They have their own specific handlers: requestOllamaModels and requestLmStudioModels. - - const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey - const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl + }) - if (litellmApiKey && litellmBaseUrl) { - modelFetchPromises.push({ - key: "litellm", - options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, + // Send ollama/lmstudio-specific messages if models were fetched + if (routerModels.ollama && Object.keys(routerModels.ollama).length > 0) { + provider.postMessageToWebview({ + type: "ollamaModels", + ollamaModels: routerModels.ollama, + }) + } + if (routerModels.lmstudio && Object.keys(routerModels.lmstudio).length > 0) { + provider.postMessageToWebview({ + type: "lmStudioModels", + lmStudioModels: routerModels.lmstudio, }) } - const results = await Promise.allSettled( - modelFetchPromises.map(async ({ key, options }) => { - const models = await safeGetModels(options) - return { key, models } // The key is `ProviderName` here. - }), - ) - - results.forEach((result, index) => { - const routerName = modelFetchPromises[index].key - - if (result.status === "fulfilled") { - routerModels[routerName] = result.value.models - - // Ollama and LM Studio settings pages still need these events. - if (routerName === "ollama" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "ollamaModels", - ollamaModels: result.value.models, - }) - } else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) { - provider.postMessageToWebview({ - type: "lmStudioModels", - lmStudioModels: result.value.models, - }) - } - } else { - // Handle rejection: Post a specific error message for this provider. - const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason) - console.error(`Error fetching models for ${routerName}:`, result.reason) - - routerModels[routerName] = {} // Ensure it's an empty object in the main routerModels message. - - provider.postMessageToWebview({ - type: "singleRouterModelFetchResponse", - success: false, - error: errorMessage, - values: { provider: routerName }, - }) - } - }) - - provider.postMessageToWebview({ type: "routerModels", routerModels }) + provider.postMessageToWebview({ type: "routerModels", routerModels: routerModels as RouterModels }) break + } case "requestOllamaModels": { // Specific handler for Ollama models only. 
const { apiConfiguration: ollamaApiConfig } = await provider.getState() @@ -901,7 +854,8 @@ export const webviewMessageHandler = async ( provider.postMessageToWebview({ type: "ollamaModels", ollamaModels: ollamaModels }) } } catch (error) { - // Silently fail - user hasn't configured Ollama yet + // Silently fail - user hasn't configured Ollama yet (debug level only) + // Using console.debug since this is expected when Ollama isn't configured console.debug("Ollama models fetch failed:", error) } break @@ -925,7 +879,8 @@ export const webviewMessageHandler = async ( }) } } catch (error) { - // Silently fail - user hasn't configured LM Studio yet. + // Silently fail - user hasn't configured LM Studio yet (debug level only) + // Using console.debug since this is expected when LM Studio isn't configured console.debug("LM Studio models fetch failed:", error) } break @@ -990,7 +945,9 @@ export const webviewMessageHandler = async ( huggingFaceModels: huggingFaceModelsResponse.models, }) } catch (error) { - console.error("Failed to fetch Hugging Face models:", error) + provider.log( + `Failed to fetch Hugging Face models: ${error instanceof Error ? error.message : String(error)}`, + ) provider.postMessageToWebview({ type: "huggingFaceModels", huggingFaceModels: [] }) } break @@ -1016,10 +973,11 @@ export const webviewMessageHandler = async ( } break case "checkpointDiff": - const result = checkoutDiffPayloadSchema.safeParse(message.payload) + const diffResult = checkoutDiffPayloadSchema.safeParse(message.payload) - if (result.success) { - await provider.getCurrentTask()?.checkpointDiff(result.data) + if (diffResult.success) { + // Cast to the correct CheckpointDiffOptions type (mode can be "from-init" | "checkpoint" | "to-current" | "full") + await provider.getCurrentTask()?.checkpointDiff(diffResult.data as any) } break @@ -1308,7 +1266,7 @@ export const webviewMessageHandler = async ( break case "checkpointTimeout": const checkpointTimeout = message.value ?? DEFAULT_CHECKPOINT_TIMEOUT_SECONDS - await updateGlobalState("checkpointTimeout", checkpointTimeout) + await provider.contextProxy.setValue("checkpointTimeout", checkpointTimeout) await provider.postStateToWebview() break case "browserViewportSize": @@ -1658,14 +1616,6 @@ export const webviewMessageHandler = async ( await updateGlobalState("includeDiagnosticMessages", includeValue) await provider.postStateToWebview() break - case "includeCurrentTime": - await updateGlobalState("includeCurrentTime", message.bool ?? true) - await provider.postStateToWebview() - break - case "includeCurrentCost": - await updateGlobalState("includeCurrentCost", message.bool ?? true) - await provider.postStateToWebview() - break case "maxDiagnosticMessages": await updateGlobalState("maxDiagnosticMessages", message.value ?? 50) await provider.postStateToWebview() @@ -1701,6 +1651,14 @@ export const webviewMessageHandler = async ( await updateGlobalState("includeTaskHistoryInEnhance", message.bool ?? true) await provider.postStateToWebview() break + case "includeCurrentTime": + await updateGlobalState("includeCurrentTime", message.bool ?? true) + await provider.postStateToWebview() + break + case "includeCurrentCost": + await updateGlobalState("includeCurrentCost", message.bool ?? 
true) + await provider.postStateToWebview() + break case "condensingApiConfigId": await updateGlobalState("condensingApiConfigId", message.text) await provider.postStateToWebview() diff --git a/src/services/router-models/__tests__/router-models-service.spec.ts b/src/services/router-models/__tests__/router-models-service.spec.ts new file mode 100644 index 000000000000..55c05de9cdda --- /dev/null +++ b/src/services/router-models/__tests__/router-models-service.spec.ts @@ -0,0 +1,266 @@ +import { describe, it, expect, beforeEach, vi } from "vitest" +import type { Mock } from "vitest" +import type { ProviderSettings } from "@roo-code/types" +import { fetchRouterModels } from "../index" +import { getModels } from "../../../api/providers/fetchers/modelCache" +import { CloudService } from "@roo-code/cloud" + +// Mock dependencies +vi.mock("../../../api/providers/fetchers/modelCache") +vi.mock("@roo-code/cloud") + +const mockGetModels = getModels as Mock +const mockCloudService = CloudService as any + +describe("RouterModelsService", () => { + const mockModels = { + "test-model": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model", + }, + } + + const baseApiConfiguration: ProviderSettings = { + apiProvider: "openrouter", + openRouterApiKey: "test-key", + requestyApiKey: "requesty-key", + unboundApiKey: "unbound-key", + ioIntelligenceApiKey: "io-key", + deepInfraApiKey: "deepinfra-key", + litellmApiKey: "litellm-key", + litellmBaseUrl: "http://localhost:4000", + } + + beforeEach(() => { + vi.clearAllMocks() + mockGetModels.mockResolvedValue(mockModels) + mockCloudService.hasInstance = vi.fn().mockReturnValue(false) + }) + + describe("fetchRouterModels", () => { + it("fetches all providers when activeProviderOnly is false", async () => { + const result = await fetchRouterModels({ + apiConfiguration: baseApiConfiguration, + activeProviderOnly: false, + }) + + // Should fetch all standard providers + expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ provider: "requesty", apiKey: "requesty-key" }), + ) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" }) + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ provider: "deepinfra", apiKey: "deepinfra-key" }), + ) + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "roo", + baseUrl: "https://api.roocode.com/proxy", + }), + ) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "io-intelligence", apiKey: "io-key" }) + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "litellm-key", + baseUrl: "http://localhost:4000", + }) + + // Should return models for all providers + expect(result.routerModels).toHaveProperty("openrouter") + expect(result.routerModels).toHaveProperty("requesty") + expect(result.routerModels).toHaveProperty("glama") + expect(result.errors).toEqual([]) + }) + + it("fetches only active provider when activeProviderOnly is true", async () => { + const result = await fetchRouterModels({ + apiConfiguration: { ...baseApiConfiguration, apiProvider: "openrouter" }, + activeProviderOnly: true, + }) + + // Should only fetch openrouter + expect(mockGetModels).toHaveBeenCalledTimes(1) + expect(mockGetModels).toHaveBeenCalledWith({ provider: 
"openrouter" }) + + // Should return models only for openrouter + expect(result.routerModels.openrouter).toEqual(mockModels) + expect(result.errors).toEqual([]) + }) + + it("includes ollama when it is the active provider", async () => { + const config: ProviderSettings = { + ...baseApiConfiguration, + apiProvider: "ollama", + ollamaBaseUrl: "http://localhost:11434", + } + + await fetchRouterModels({ + apiConfiguration: config, + activeProviderOnly: true, + }) + + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "ollama", + baseUrl: "http://localhost:11434", + apiKey: undefined, + }) + }) + + it("includes lmstudio when it is the active provider", async () => { + const config: ProviderSettings = { + ...baseApiConfiguration, + apiProvider: "lmstudio", + lmStudioBaseUrl: "http://localhost:1234", + } + + await fetchRouterModels({ + apiConfiguration: config, + activeProviderOnly: true, + }) + + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "lmstudio", + baseUrl: "http://localhost:1234", + }) + }) + + it("includes huggingface when it is the active provider", async () => { + const config: ProviderSettings = { + ...baseApiConfiguration, + apiProvider: "huggingface", + } + + await fetchRouterModels({ + apiConfiguration: config, + activeProviderOnly: true, + }) + + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "huggingface", + }) + }) + + it("uses litellmOverrides when provided", async () => { + await fetchRouterModels({ + apiConfiguration: { ...baseApiConfiguration, litellmApiKey: undefined, litellmBaseUrl: undefined }, + activeProviderOnly: false, + litellmOverrides: { + apiKey: "override-key", + baseUrl: "http://override:5000", + }, + }) + + expect(mockGetModels).toHaveBeenCalledWith({ + provider: "litellm", + apiKey: "override-key", + baseUrl: "http://override:5000", + }) + }) + + it("handles provider fetch errors gracefully", async () => { + mockGetModels + .mockResolvedValueOnce(mockModels) // openrouter succeeds + .mockRejectedValueOnce(new Error("Requesty API error")) // requesty fails + .mockResolvedValueOnce(mockModels) // glama succeeds + + const result = await fetchRouterModels({ + apiConfiguration: baseApiConfiguration, + activeProviderOnly: false, + }) + + // Should have errors for failed provider + expect(result.errors).toHaveLength(1) + expect(result.errors[0]).toEqual({ + provider: "requesty", + error: "Requesty API error", + }) + + // Should have empty object for failed provider + expect(result.routerModels.requesty).toEqual({}) + + // Should have models for successful providers + expect(result.routerModels.openrouter).toEqual(mockModels) + }) + + it("skips litellm when no api key or base url provided", async () => { + const config: ProviderSettings = { + ...baseApiConfiguration, + litellmApiKey: undefined, + litellmBaseUrl: undefined, + } + + await fetchRouterModels({ + apiConfiguration: config, + activeProviderOnly: false, + }) + + // Should not call getModels for litellm + expect(mockGetModels).not.toHaveBeenCalledWith(expect.objectContaining({ provider: "litellm" })) + }) + + it("skips io-intelligence when no api key provided", async () => { + const config: ProviderSettings = { + ...baseApiConfiguration, + ioIntelligenceApiKey: undefined, + } + + await fetchRouterModels({ + apiConfiguration: config, + activeProviderOnly: false, + }) + + // Should not call getModels for io-intelligence + expect(mockGetModels).not.toHaveBeenCalledWith(expect.objectContaining({ provider: "io-intelligence" })) + }) + + it("uses roo session token when CloudService is 
available", async () => { + const mockAuthService = { + getSessionToken: vi.fn().mockReturnValue("session-token-123"), + } + + vi.mocked(CloudService.hasInstance).mockReturnValue(true) + Object.defineProperty(CloudService, "instance", { + get: () => ({ authService: mockAuthService }), + configurable: true, + }) + + await fetchRouterModels({ + apiConfiguration: baseApiConfiguration, + activeProviderOnly: false, + }) + + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "roo", + apiKey: "session-token-123", + }), + ) + }) + + it("initializes all providers with empty objects", async () => { + const result = await fetchRouterModels({ + apiConfiguration: { apiProvider: "openrouter" } as ProviderSettings, + activeProviderOnly: true, + }) + + // All providers should be initialized even if not fetched + expect(result.routerModels).toHaveProperty("openrouter") + expect(result.routerModels).toHaveProperty("requesty") + expect(result.routerModels).toHaveProperty("glama") + expect(result.routerModels).toHaveProperty("unbound") + expect(result.routerModels).toHaveProperty("vercel-ai-gateway") + expect(result.routerModels).toHaveProperty("deepinfra") + expect(result.routerModels).toHaveProperty("roo") + expect(result.routerModels).toHaveProperty("litellm") + expect(result.routerModels).toHaveProperty("ollama") + expect(result.routerModels).toHaveProperty("lmstudio") + expect(result.routerModels).toHaveProperty("huggingface") + expect(result.routerModels).toHaveProperty("io-intelligence") + }) + }) +}) diff --git a/src/services/router-models/index.ts b/src/services/router-models/index.ts new file mode 100644 index 000000000000..52a27ee59973 --- /dev/null +++ b/src/services/router-models/index.ts @@ -0,0 +1,171 @@ +import type { ProviderSettings } from "@roo-code/types" +import { CloudService } from "@roo-code/cloud" +import type { RouterName, ModelRecord, GetModelsOptions } from "../../shared/api" +import { getModels } from "../../api/providers/fetchers/modelCache" + +export interface RouterModelsFetchOptions { + apiConfiguration: ProviderSettings + activeProviderOnly?: boolean + litellmOverrides?: { + apiKey?: string + baseUrl?: string + } +} + +export interface RouterModelsFetchResult { + routerModels: Partial> + errors: Array<{ + provider: RouterName + error: string + }> +} + +/** + * Builds the list of provider fetch options based on configuration and mode. + */ +function buildProviderFetchList( + options: RouterModelsFetchOptions, +): Array<{ key: RouterName; options: GetModelsOptions }> { + const { apiConfiguration, activeProviderOnly, litellmOverrides } = options + + const allFetches: Array<{ key: RouterName; options: GetModelsOptions }> = [ + { key: "openrouter", options: { provider: "openrouter" } }, + { + key: "requesty", + options: { + provider: "requesty", + apiKey: apiConfiguration.requestyApiKey, + baseUrl: apiConfiguration.requestyBaseUrl, + }, + }, + { key: "glama", options: { provider: "glama" } }, + { key: "unbound", options: { provider: "unbound", apiKey: apiConfiguration.unboundApiKey } }, + { key: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } }, + { + key: "deepinfra", + options: { + provider: "deepinfra", + apiKey: apiConfiguration.deepInfraApiKey, + baseUrl: apiConfiguration.deepInfraBaseUrl, + }, + }, + { + key: "roo", + options: { + provider: "roo", + baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", + apiKey: CloudService.hasInstance() ? 
CloudService.instance.authService?.getSessionToken() : undefined, + }, + }, + ] + + // Include local providers when in active-provider mode and they are selected + if (activeProviderOnly) { + const activeProvider = apiConfiguration.apiProvider + + if (activeProvider === "ollama") { + allFetches.push({ + key: "ollama", + options: { + provider: "ollama", + baseUrl: apiConfiguration.ollamaBaseUrl, + apiKey: apiConfiguration.ollamaApiKey, + }, + }) + } + if (activeProvider === "lmstudio") { + allFetches.push({ + key: "lmstudio", + options: { + provider: "lmstudio", + baseUrl: apiConfiguration.lmStudioBaseUrl, + }, + }) + } + if (activeProvider === "huggingface") { + allFetches.push({ + key: "huggingface", + options: { + provider: "huggingface", + }, + }) + } + } + + // Add IO Intelligence if API key is provided + if (apiConfiguration.ioIntelligenceApiKey) { + allFetches.push({ + key: "io-intelligence", + options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey }, + }) + } + + // Add LiteLLM if configured (with potential overrides from message) + const litellmApiKey = apiConfiguration.litellmApiKey || litellmOverrides?.apiKey + const litellmBaseUrl = apiConfiguration.litellmBaseUrl || litellmOverrides?.baseUrl + if (litellmApiKey && litellmBaseUrl) { + allFetches.push({ + key: "litellm", + options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, + }) + } + + return allFetches +} + +/** + * Fetches router models based on the provided options. + * Can fetch all providers or only the active provider. + */ +export async function fetchRouterModels(options: RouterModelsFetchOptions): Promise<RouterModelsFetchResult> { + const { apiConfiguration, activeProviderOnly } = options + + // Initialize empty results for all providers + const routerModels: Partial<Record<RouterName, ModelRecord>> = { + openrouter: {}, + "vercel-ai-gateway": {}, + huggingface: {}, + litellm: {}, + deepinfra: {}, + "io-intelligence": {}, + requesty: {}, + unbound: {}, + glama: {}, + ollama: {}, + lmstudio: {}, + roo: {}, + } + + const errors: Array<{ provider: RouterName; error: string }> = [] + + // Build fetch list + const fetchList = buildProviderFetchList(options) + + // Filter to active provider if requested + const activeProvider = apiConfiguration.apiProvider as RouterName | undefined + const modelFetchPromises = + activeProviderOnly && activeProvider ? fetchList.filter(({ key }) => key === activeProvider) : fetchList + + // Execute fetches + const results = await Promise.allSettled( + modelFetchPromises.map(async ({ key, options }) => { + const models = await getModels(options) + return { key, models } + }), + ) + + // Process results + results.forEach((result, index) => { + const routerName = modelFetchPromises[index].key + + if (result.status === "fulfilled") { + routerModels[routerName] = result.value.models + } else { + const errorMessage = result.reason instanceof Error ?
result.reason.message : String(result.reason) + routerModels[routerName] = {} + errors.push({ provider: routerName, error: errorMessage }) + } + }) + + return { routerModels, errors } +} diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 5929e7a950eb..f8d0c5a817b8 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -294,8 +294,6 @@ export type ExtensionState = Pick< | "openRouterImageGenerationSelectedModel" | "includeTaskHistoryInEnhance" | "reasoningBlockCollapsed" - | "includeCurrentTime" - | "includeCurrentCost" > & { version: string clineMessages: ClineMessage[] @@ -324,6 +322,9 @@ export type ExtensionState = Pick< mcpEnabled: boolean enableMcpServerCreation: boolean + includeCurrentTime?: boolean + includeCurrentCost?: boolean + mode: Mode customModes: ModeConfig[] toolRequirements?: Record // Map of tool names to their requirements (e.g. {"apply_diff": true} if diffEnabled) diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index 9c475186288f..6924bc9927fb 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -66,6 +66,7 @@ export interface WebviewMessage { | "resetState" | "flushRouterModels" | "requestRouterModels" + | "requestRouterModelsAll" | "requestOpenAiModels" | "requestOllamaModels" | "requestLmStudioModels" diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 9e4d585c97c2..7e958b52ec0e 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -42,7 +42,7 @@ import { import { vscode } from "@src/utils/vscode" import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate" import { useAppTranslation } from "@src/i18n/TranslationContext" -import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" +import { useRouterModelsAll } from "@src/components/ui/hooks/useRouterModelsAll" import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel" import { useExtensionState } from "@src/context/ExtensionStateContext" import { @@ -188,7 +188,7 @@ const ApiOptions = ({ info: selectedModelInfo, } = useSelectedModel(apiConfiguration) - const { data: routerModels, refetch: refetchRouterModels } = useRouterModels() + const { data: routerModels, refetch: refetchRouterModels } = useRouterModelsAll() const { data: openRouterModelProviders } = useOpenRouterModelProviders(apiConfiguration?.openRouterModelId, { enabled: diff --git a/webview-ui/src/components/settings/providers/Unbound.tsx b/webview-ui/src/components/settings/providers/Unbound.tsx index 15826d0c0b40..de3d906ddaf9 100644 --- a/webview-ui/src/components/settings/providers/Unbound.tsx +++ b/webview-ui/src/components/settings/providers/Unbound.tsx @@ -100,11 +100,11 @@ export const Unbound = ({ window.addEventListener("message", messageHandler) }) - vscode.postMessage({ type: "requestRouterModels" }) + vscode.postMessage({ type: "requestRouterModelsAll" }) await modelsPromise - await queryClient.invalidateQueries({ queryKey: ["routerModels"] }) + await queryClient.invalidateQueries({ queryKey: ["routerModelsAll"] }) // After refreshing models, check if current model is in the updated list // If not, select the first available model diff --git a/webview-ui/src/components/ui/hooks/useRouterModelsAll.ts b/webview-ui/src/components/ui/hooks/useRouterModelsAll.ts new file mode 100644 index 000000000000..e3d38a4df0c4 --- 
/dev/null +++ b/webview-ui/src/components/ui/hooks/useRouterModelsAll.ts @@ -0,0 +1,38 @@ +import { useQuery } from "@tanstack/react-query" + +import { RouterModels } from "@roo/api" +import { ExtensionMessage } from "@roo/ExtensionMessage" + +import { vscode } from "@src/utils/vscode" + +const getRouterModelsAll = async () => + new Promise<RouterModels>((resolve, reject) => { + const cleanup = () => { + window.removeEventListener("message", handler) + } + + const timeout = setTimeout(() => { + cleanup() + reject(new Error("Router models (all) request timed out")) + }, 10000) + + const handler = (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "routerModels") { + clearTimeout(timeout) + cleanup() + + if (message.routerModels) { + resolve(message.routerModels) + } else { + reject(new Error("No router models in response")) + } + } + } + + window.addEventListener("message", handler) + vscode.postMessage({ type: "requestRouterModelsAll" }) + }) + +export const useRouterModelsAll = () => useQuery({ queryKey: ["routerModelsAll"], queryFn: getRouterModelsAll })
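
For reference, the new useRouterModelsAll hook plugs into the webview's existing React Query setup the same way the older useRouterModels hook does. The sketch below is a hypothetical consumer, not part of this diff: only useRouterModelsAll and the RouterModels shape come from the changes above; the component name and markup are illustrative.

import { useRouterModelsAll } from "@src/components/ui/hooks/useRouterModelsAll"

// Hypothetical settings widget (illustrative only): reads the full router-models
// map that the extension returns in response to a "requestRouterModelsAll" message.
export const RouterModelSummary = () => {
	const { data: routerModels, isLoading, refetch } = useRouterModelsAll()

	if (isLoading || !routerModels) {
		return <span>Loading router models...</span>
	}

	// routerModels is keyed by router name ("openrouter", "requesty", "litellm", ...);
	// each value is a ModelRecord mapping model IDs to model info.
	const total = Object.values(routerModels).reduce((sum, models) => sum + Object.keys(models).length, 0)

	return (
		<div>
			<span>{total} models across all routers</span>
			<button onClick={() => refetch()}>Refresh</button>
		</div>
	)
}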