3 changes: 2 additions & 1 deletion src/api/providers/fetchers/deepinfra.ts
@@ -35,14 +35,15 @@ const DeepInfraModelsResponseSchema = z.object({ data: z.array(DeepInfraModelSch
export async function getDeepInfraModels(
apiKey?: string,
baseUrl: string = "https://api.deepinfra.com/v1/openai",
signal?: AbortSignal,
): Promise<Record<string, ModelInfo>> {
const headers: Record<string, string> = { ...DEFAULT_HEADERS }
if (apiKey) headers["Authorization"] = `Bearer ${apiKey}`

const url = `${baseUrl.replace(/\/$/, "")}/models`
const models: Record<string, ModelInfo> = {}

const response = await axios.get(url, { headers })
const response = await axios.get(url, { headers, signal })
const parsed = DeepInfraModelsResponseSchema.safeParse(response.data)
const data = parsed.success ? parsed.data.data : response.data?.data || []

4 changes: 2 additions & 2 deletions src/api/providers/fetchers/glama.ts
@@ -4,11 +4,11 @@ import type { ModelInfo } from "@roo-code/types"

import { parseApiPrice } from "../../../shared/cost"

export async function getGlamaModels(): Promise<Record<string, ModelInfo>> {
export async function getGlamaModels(signal?: AbortSignal): Promise<Record<string, ModelInfo>> {
const models: Record<string, ModelInfo> = {}

try {
const response = await axios.get("https://glama.ai/api/gateway/v1/models")
const response = await axios.get("https://glama.ai/api/gateway/v1/models", { signal })
const rawModels = response.data

for (const rawModel of rawModels) {
6 changes: 3 additions & 3 deletions src/api/providers/fetchers/huggingface.ts
@@ -107,7 +107,7 @@ function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFacePr
* @returns A promise that resolves to a record of model IDs to model info
* @throws Will throw an error if the request fails
*/
export async function getHuggingFaceModels(): Promise<ModelRecord> {
export async function getHuggingFaceModels(signal?: AbortSignal): Promise<ModelRecord> {
const now = Date.now()

if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) {
@@ -128,7 +128,7 @@ export async function getHuggingFaceModels(): Promise<ModelRecord> {
Pragma: "no-cache",
"Cache-Control": "no-cache",
},
timeout: 10000,
signal,
})

const result = huggingFaceApiResponseSchema.safeParse(response.data)
@@ -236,7 +236,7 @@ export async function getHuggingFaceModelsWithMetadata(): Promise<HuggingFaceMod
Pragma: "no-cache",
"Cache-Control": "no-cache",
},
timeout: 10000,
signal: AbortSignal.timeout(30000),
})

const models = response.data?.data || []
4 changes: 2 additions & 2 deletions src/api/providers/fetchers/io-intelligence.ts
@@ -83,7 +83,7 @@ function parseIOIntelligenceModel(model: IOIntelligenceModel): ModelInfo {
* Fetches available models from IO Intelligence
* <mcreference link="https://docs.io.net/reference/get-started-with-io-intelligence-api" index="1">1</mcreference>
*/
export async function getIOIntelligenceModels(apiKey?: string): Promise<ModelRecord> {
export async function getIOIntelligenceModels(apiKey?: string, signal?: AbortSignal): Promise<ModelRecord> {
const now = Date.now()

if (cache && now - cache.timestamp < IO_INTELLIGENCE_CACHE_DURATION) {
@@ -108,7 +108,7 @@ export async function getIOIntelligenceModels(apiKey?: string): Promise<ModelRec
"https://api.intelligence.io.solutions/api/v1/models",
{
headers,
timeout: 10_000,
signal,
},
)

5 changes: 2 additions & 3 deletions src/api/providers/fetchers/litellm.ts
@@ -11,7 +11,7 @@ import { DEFAULT_HEADERS } from "../constants"
* @returns A promise that resolves to a record of model IDs to model info
* @throws Will throw an error if the request fails or the response is not as expected.
*/
export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise<ModelRecord> {
export async function getLiteLLMModels(apiKey: string, baseUrl: string, signal?: AbortSignal): Promise<ModelRecord> {
try {
const headers: Record<string, string> = {
"Content-Type": "application/json",
@@ -27,8 +27,7 @@ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise
// Normalize the pathname by removing trailing slashes and multiple slashes
urlObj.pathname = urlObj.pathname.replace(/\/+$/, "").replace(/\/+/g, "/") + "/v1/model/info"
const url = urlObj.href
// Added timeout to prevent indefinite hanging
const response = await axios.get(url, { headers, timeout: 5000 })
const response = await axios.get(url, { headers, signal })
const models: ModelRecord = {}

// Process the model info from the response
7 changes: 5 additions & 2 deletions src/api/providers/fetchers/lmstudio.ts
@@ -49,7 +49,10 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn
return modelInfo
}

export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<Record<string, ModelInfo>> {
export async function getLMStudioModels(
baseUrl = "http://localhost:1234",
signal?: AbortSignal,
): Promise<Record<string, ModelInfo>> {
// clear the set of models that have full details loaded
modelsWithLoadedDetails.clear()
// clearing the input can leave an empty string; use the default in that case
@@ -66,7 +69,7 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom

// test the connection to LM Studio first
// errors will be caught further down
await axios.get(`${baseUrl}/v1/models`)
await axios.get(`${baseUrl}/v1/models`, { signal })

const client = new LMStudioClient({ baseUrl: lmsUrl })

164 changes: 124 additions & 40 deletions src/api/providers/fetchers/modelCache.ts
@@ -3,8 +3,6 @@ import fs from "fs/promises"

import NodeCache from "node-cache"

import type { ProviderName } from "@roo-code/types"

import { safeWriteJson } from "../../../utils/safeWriteJson"

import { ContextProxy } from "../../../core/config/ContextProxy"
@@ -28,6 +26,9 @@ import { getRooModels } from "./roo"

const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })

// Coalesce concurrent fetches per provider within this extension host
const inFlightModelFetches = new Map<RouterName, Promise<ModelRecord>>()
Review comment: The in-flight coalescing key is too coarse. Coalescing by provider alone can return incorrect results for providers whose model lists depend on options (baseUrl/apiKey), e.g. 'litellm', 'requesty', 'roo', 'ollama', 'lmstudio', 'deepinfra', and 'io-intelligence'. Two concurrent calls with different options will share the same in-flight promise and also write to the same file-cache key, causing cross-config mixing. Consider deriving a composite key (provider + normalized baseUrl + an auth identity hint, e.g. a hash of the apiKey or the token subject) and using it for both the in-flight map key and the file-cache filename.
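
One possible shape for such a composite key, as a sketch only: the provider/baseUrl/apiKey fields mirror what getModels already receives via GetModelsOptions, while the helper name, the sha256 fingerprint, and the filename sanitization are illustrative assumptions rather than part of this PR.

import { createHash } from "crypto"

// Sketch only: stand-in for the RouterName union used in modelCache.ts.
type RouterName = string

interface ModelCacheKeyParts {
	provider: RouterName
	baseUrl?: string
	apiKey?: string
}

// Build a composite key from the provider, a normalized baseUrl, and a short
// fingerprint of the apiKey, so concurrent fetches with different endpoints or
// credentials never share an in-flight promise or a file-cache entry.
// The raw apiKey is never embedded in the key; only a truncated hash is used.
function getModelCacheKey({ provider, baseUrl, apiKey }: ModelCacheKeyParts): string {
	const normalizedBaseUrl = (baseUrl ?? "").trim().replace(/\/+$/, "").toLowerCase()
	const apiKeyFingerprint = apiKey ? createHash("sha256").update(apiKey).digest("hex").slice(0, 8) : "anon"
	return `${provider}:${normalizedBaseUrl}:${apiKeyFingerprint}`
}

// Usage sketch: key both the in-flight map and the file-cache filename.
// const key = getModelCacheKey({ provider, baseUrl: options.baseUrl, apiKey: options.apiKey })
// const inFlight = inFlightModelFetches.get(key)
// const filename = `${key.replace(/[^a-zA-Z0-9_-]/g, "_")}_models.json`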


async function writeModels(router: RouterName, data: ModelRecord) {
const filename = `${router}_models.json`
const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
@@ -56,82 +57,165 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
const { provider } = options

let models = getModelsFromCache(provider)

if (models) {
return models
// 1) Try memory cache
const cached = getModelsFromCache(provider)
if (cached) {
return cached
}

// 2) Try file cache snapshot (Option A), then kick off background refresh
try {
const file = await readModels(provider)
if (file && Object.keys(file).length > 0) {
memoryCache.set(provider, file)

// Start background refresh if not already in-flight (do not await)
if (!inFlightModelFetches.has(provider)) {
const signal = AbortSignal.timeout(30_000)
const bgPromise = (async (): Promise<ModelRecord> => {
let models: ModelRecord = {}
switch (provider) {
case "openrouter":
models = await getOpenRouterModels(undefined, signal)
break
case "requesty":
models = await getRequestyModels(options.baseUrl, options.apiKey, signal)
break
case "glama":
models = await getGlamaModels(signal)
break
case "unbound":
models = await getUnboundModels(options.apiKey, signal)
break
case "litellm":
models = await getLiteLLMModels(options.apiKey as string, options.baseUrl as string, signal)
break
case "ollama":
models = await getOllamaModels(options.baseUrl, options.apiKey, signal)
break
case "lmstudio":
models = await getLMStudioModels(options.baseUrl, signal)
break
case "deepinfra":
models = await getDeepInfraModels(options.apiKey, options.baseUrl, signal)
break
case "io-intelligence":
models = await getIOIntelligenceModels(options.apiKey, signal)
break
case "vercel-ai-gateway":
models = await getVercelAiGatewayModels(undefined, signal)
break
case "huggingface":
models = await getHuggingFaceModels(signal)
break
case "roo": {
const rooBaseUrl =
options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
models = await getRooModels(rooBaseUrl, options.apiKey, signal)
break
}
default:
throw new Error(`Unknown provider: ${provider}`)
}

memoryCache.set(provider, models)
await writeModels(provider, models).catch((err) => {
console.error(
`[modelCache] Error writing ${provider} to file cache during background refresh:`,
err instanceof Error ? err.message : String(err),
)
})
return models || {}
})()

inFlightModelFetches.set(provider, bgPromise)
Promise.resolve(bgPromise)
.catch((err) => {
console.error(
`[modelCache] Background refresh failed for ${provider}:`,
err instanceof Error ? err.message : String(err),
)
})
.finally(() => inFlightModelFetches.delete(provider))
}

return file
}
} catch {
// ignore file read errors; fall through to network/coalesce path
}

// 3) Coalesce concurrent fetches
const existing = inFlightModelFetches.get(provider)
if (existing) {
return existing
}

// 4) Network fetch wrapped as a single in-flight promise for this provider
const signal = AbortSignal.timeout(30_000)
const fetchPromise = (async (): Promise<ModelRecord> => {
let models: ModelRecord = {}
switch (provider) {
case "openrouter":
models = await getOpenRouterModels()
models = await getOpenRouterModels(undefined, signal)
break
case "requesty":
// Requesty models endpoint requires an API key for per-user custom policies.
models = await getRequestyModels(options.baseUrl, options.apiKey)
models = await getRequestyModels(options.baseUrl, options.apiKey, signal)
break
case "glama":
models = await getGlamaModels()
models = await getGlamaModels(signal)
break
case "unbound":
// Unbound models endpoint requires an API key to fetch application specific models.
models = await getUnboundModels(options.apiKey)
models = await getUnboundModels(options.apiKey, signal)
break
case "litellm":
// Type safety ensures apiKey and baseUrl are always provided for LiteLLM.
models = await getLiteLLMModels(options.apiKey, options.baseUrl)
models = await getLiteLLMModels(options.apiKey as string, options.baseUrl as string, signal)
break
case "ollama":
models = await getOllamaModels(options.baseUrl, options.apiKey)
models = await getOllamaModels(options.baseUrl, options.apiKey, signal)
break
case "lmstudio":
models = await getLMStudioModels(options.baseUrl)
models = await getLMStudioModels(options.baseUrl, signal)
break
case "deepinfra":
models = await getDeepInfraModels(options.apiKey, options.baseUrl)
models = await getDeepInfraModels(options.apiKey, options.baseUrl, signal)
break
case "io-intelligence":
models = await getIOIntelligenceModels(options.apiKey)
models = await getIOIntelligenceModels(options.apiKey, signal)
break
case "vercel-ai-gateway":
models = await getVercelAiGatewayModels()
models = await getVercelAiGatewayModels(undefined, signal)
break
case "huggingface":
models = await getHuggingFaceModels()
models = await getHuggingFaceModels(signal)
break
case "roo": {
// Roo Code Cloud provider requires baseUrl and optional apiKey
const rooBaseUrl =
options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
models = await getRooModels(rooBaseUrl, options.apiKey)
models = await getRooModels(rooBaseUrl, options.apiKey, signal)
break
}
default: {
// Ensures router is exhaustively checked if RouterName is a strict union.
const exhaustiveCheck: never = provider
throw new Error(`Unknown provider: ${exhaustiveCheck}`)
throw new Error(`Unknown provider: ${provider}`)
}
}

// Cache the fetched models (even if empty, to signify a successful fetch with no models).
memoryCache.set(provider, models)

await writeModels(provider, models).catch((err) =>
console.error(`[getModels] Error writing ${provider} models to file cache:`, err),
)
await writeModels(provider, models).catch((err) => {
console.error(
`[modelCache] Error writing ${provider} to file cache after network fetch:`,
err instanceof Error ? err.message : String(err),
)
})

try {
models = await readModels(provider)
} catch (error) {
console.error(`[getModels] error reading ${provider} models from file cache`, error)
}
return models || {}
} catch (error) {
// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
})()

throw error // Re-throw the original error to be handled by the caller.
inFlightModelFetches.set(provider, fetchPromise)
try {
return await fetchPromise
} finally {
inFlightModelFetches.delete(provider)
}
}

@@ -144,6 +228,6 @@ export const flushModels = async (router: RouterName) => {
memoryCache.del(router)
}

export function getModelsFromCache(provider: ProviderName) {
export function getModelsFromCache(provider: RouterName) {
return memoryCache.get<ModelRecord>(provider)
}