From 49521cc4998bc273551630f77980e3d78cb70c84 Mon Sep 17 00:00:00 2001 From: Hannes Rudolph Date: Wed, 5 Nov 2025 14:43:22 -0700 Subject: [PATCH 1/3] Implements the model metadata caching refactor: persists resolvedModelInfo, removes TTL-based auto-expiration, adds explicit Refresh models flow, activation-time self-healing, and gated reinit on provider/model/baseUrl changes. --- packages/types/src/provider-settings.ts | 2 + src/__tests__/extension.spec.ts | 438 +++++++++--------- .../settings-schema.resolvedModelInfo.spec.ts | 24 + .../webviewMessageHandler.refresh.spec.ts | 155 +++++++ src/api/providers/__tests__/chutes.spec.ts | 42 +- src/api/providers/__tests__/deepinfra.spec.ts | 62 +++ src/api/providers/__tests__/glama.spec.ts | 38 +- .../__tests__/io-intelligence.spec.ts | 34 ++ src/api/providers/__tests__/lite-llm.spec.ts | 52 ++- .../providers/__tests__/openrouter.spec.ts | 53 +++ src/api/providers/__tests__/requesty.spec.ts | 29 ++ src/api/providers/__tests__/roo.spec.ts | 45 ++ src/api/providers/__tests__/unbound.spec.ts | 38 +- .../__tests__/vercel-ai-gateway.spec.ts | 46 ++ src/api/providers/chutes.ts | 4 +- src/api/providers/deepinfra.ts | 12 +- .../fetchers/__tests__/modelCache.spec.ts | 157 ++++++- src/api/providers/fetchers/modelCache.ts | 14 +- .../providers/fetchers/modelEndpointCache.ts | 20 +- src/api/providers/fetchers/openrouter.ts | 6 + src/api/providers/glama.ts | 4 +- src/api/providers/io-intelligence.ts | 21 +- src/api/providers/lite-llm.ts | 4 +- src/api/providers/openrouter.ts | 14 +- src/api/providers/requesty.ts | 10 +- src/api/providers/roo.ts | 12 +- src/api/providers/router-provider.ts | 10 +- src/api/providers/unbound.ts | 4 +- src/api/providers/vercel-ai-gateway.ts | 4 +- src/core/sliding-window/index.ts | 13 + src/core/webview/ClineProvider.ts | 83 +++- src/core/webview/webviewMessageHandler.ts | 108 ++++- src/extension.ts | 61 ++- src/shared/ExtensionMessage.ts | 1 + .../src/components/settings/ApiOptions.tsx | 6 - .../src/components/settings/ModelPicker.tsx | 53 ++- ...lPicker.persist-resolvedModelInfo.spec.tsx | 90 ++++ .../__tests__/ModelPicker.refresh.spec.tsx | 139 ++++++ .../hooks/__tests__/useRouterModels.spec.tsx | 158 +++++++ .../components/ui/hooks/useRouterModels.ts | 12 +- .../components/ui/hooks/useSelectedModel.ts | 12 +- 41 files changed, 1772 insertions(+), 318 deletions(-) create mode 100644 src/__tests__/settings-schema.resolvedModelInfo.spec.ts create mode 100644 src/__tests__/webviewMessageHandler.refresh.spec.ts create mode 100644 src/api/providers/__tests__/deepinfra.spec.ts create mode 100644 webview-ui/src/components/settings/__tests__/ModelPicker.persist-resolvedModelInfo.spec.tsx create mode 100644 webview-ui/src/components/settings/__tests__/ModelPicker.refresh.spec.tsx create mode 100644 webview-ui/src/components/ui/hooks/__tests__/useRouterModels.spec.tsx diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 23e0a548d113..5c2fc81fb91a 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -179,6 +179,8 @@ const baseProviderSettingsSchema = z.object({ reasoningEffort: reasoningEffortWithMinimalSchema.optional(), modelMaxTokens: z.number().optional(), modelMaxThinkingTokens: z.number().optional(), + // Persisted resolved model metadata (Phase 1 Step 1) + resolvedModelInfo: modelInfoSchema.optional(), // Model verbosity. 
verbosity: verbosityLevelsSchema.optional(), diff --git a/src/__tests__/extension.spec.ts b/src/__tests__/extension.spec.ts index 6c12c473fb99..149d177002ae 100644 --- a/src/__tests__/extension.spec.ts +++ b/src/__tests__/extension.spec.ts @@ -1,258 +1,254 @@ -// npx vitest run __tests__/extension.spec.ts - -import type * as vscode from "vscode" -import type { AuthState } from "@roo-code/types" +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" vi.mock("vscode", () => ({ window: { - createOutputChannel: vi.fn().mockReturnValue({ - appendLine: vi.fn(), - }), - registerWebviewViewProvider: vi.fn(), - registerUriHandler: vi.fn(), - tabGroups: { - onDidChangeTabs: vi.fn(), - }, - onDidChangeActiveTextEditor: vi.fn(), + createTextEditorDecorationType: vi.fn().mockReturnValue({ dispose: vi.fn() }), + showErrorMessage: vi.fn(), + showInformationMessage: vi.fn(), }, + CodeActionKind: { + QuickFix: { value: "quickfix" }, + RefactorRewrite: { value: "refactor.rewrite" }, + }, + RelativePattern: vi.fn(), workspace: { - registerTextDocumentContentProvider: vi.fn(), - getConfiguration: vi.fn().mockReturnValue({ - get: vi.fn().mockReturnValue([]), - }), createFileSystemWatcher: vi.fn().mockReturnValue({ - onDidCreate: vi.fn(), onDidChange: vi.fn(), + onDidCreate: vi.fn(), onDidDelete: vi.fn(), - dispose: vi.fn(), }), - onDidChangeWorkspaceFolders: vi.fn(), - }, - languages: { - registerCodeActionsProvider: vi.fn(), - }, - commands: { - executeCommand: vi.fn(), - }, - env: { - language: "en", - }, - ExtensionMode: { - Production: 1, + getConfiguration: vi.fn().mockReturnValue({ update: vi.fn() }), }, + env: { language: "en" }, })) -vi.mock("@dotenvx/dotenvx", () => ({ - config: vi.fn(), +vi.mock("../api", () => ({ + buildApiHandler: vi.fn(), })) -const mockBridgeOrchestratorDisconnect = vi.fn().mockResolvedValue(undefined) - -vi.mock("@roo-code/cloud", () => ({ - CloudService: { - createInstance: vi.fn(), - hasInstance: vi.fn().mockReturnValue(true), - get instance() { - return { - off: vi.fn(), - on: vi.fn(), - getUserInfo: vi.fn().mockReturnValue(null), - isTaskSyncEnabled: vi.fn().mockReturnValue(false), - } - }, - }, - BridgeOrchestrator: { - disconnect: mockBridgeOrchestratorDisconnect, - }, - getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), -})) - -vi.mock("@roo-code/telemetry", () => ({ - TelemetryService: { - createInstance: vi.fn().mockReturnValue({ - register: vi.fn(), - setProvider: vi.fn(), - shutdown: vi.fn(), - }), - get instance() { - return { - register: vi.fn(), - setProvider: vi.fn(), - shutdown: vi.fn(), - } - }, - }, - PostHogTelemetryClient: vi.fn(), -})) +import { ensureResolvedModelInfo } from "../extension" +import { buildApiHandler } from "../api" +import { ClineProvider } from "../core/webview/ClineProvider" -vi.mock("../utils/outputChannelLogger", () => ({ - createOutputChannelLogger: vi.fn().mockReturnValue(vi.fn()), - createDualLogger: vi.fn().mockReturnValue(vi.fn()), -})) +describe("activation-time resolvedModelInfo", () => { + let logSpy: ReturnType + let warnSpy: ReturnType -vi.mock("../shared/package", () => ({ - Package: { - name: "test-extension", - outputChannel: "Test Output", - version: "1.0.0", - }, -})) + beforeEach(() => { + vi.clearAllMocks() + logSpy = vi.spyOn(console, "log").mockImplementation(() => {}) + warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}) + }) -vi.mock("../shared/language", () => ({ - formatLanguage: vi.fn().mockReturnValue("en"), -})) + afterEach(() => { + logSpy.mockRestore() + 
warnSpy.mockRestore() + }) -vi.mock("../core/config/ContextProxy", () => ({ - ContextProxy: { - getInstance: vi.fn().mockResolvedValue({ - getValue: vi.fn(), - setValue: vi.fn(), - getValues: vi.fn().mockReturnValue({}), - getProviderSettings: vi.fn().mockReturnValue({}), - }), - }, -})) + it("populates missing resolvedModelInfo for a dynamic provider on activation", async () => { + const provider: any = { + getState: vi.fn().mockResolvedValue({ + apiConfiguration: { apiProvider: "openrouter", openRouterModelId: "openrouter/model" }, + currentApiConfigName: "default", + }), + upsertProviderProfile: vi.fn().mockResolvedValue("id"), + } + + const info = { contextWindow: 4000, maxTokens: 8192, supportsPromptCache: true } + const handler = { + fetchModel: vi.fn().mockResolvedValue({ info }), + getModel: vi.fn().mockReturnValue({ id: "openrouter/model", info }), + } + ;(buildApiHandler as any).mockReturnValue(handler) + + await ensureResolvedModelInfo(provider) + + expect(buildApiHandler).toHaveBeenCalled() + expect(provider.upsertProviderProfile).toHaveBeenCalledWith( + "default", + expect.objectContaining({ resolvedModelInfo: info }), + true, + ) + expect(logSpy.mock.calls.some((c: any[]) => String(c.join(" ")).includes("Populating resolvedModelInfo"))).toBe( + true, + ) + }) -vi.mock("../integrations/editor/DiffViewProvider", () => ({ - DIFF_VIEW_URI_SCHEME: "test-diff-scheme", -})) + it("skips when resolvedModelInfo is valid", async () => { + const resolved = { contextWindow: 16000, maxTokens: 8000 } + const provider: any = { + getState: vi.fn().mockResolvedValue({ + apiConfiguration: { apiProvider: "openrouter", resolvedModelInfo: resolved }, + currentApiConfigName: "default", + }), + upsertProviderProfile: vi.fn(), + } + + await ensureResolvedModelInfo(provider) + + expect(buildApiHandler).not.toHaveBeenCalled() + expect(provider.upsertProviderProfile).not.toHaveBeenCalled() + expect( + logSpy.mock.calls.some((c: any[]) => String(c.join(" ")).includes("Using existing resolvedModelInfo")), + ).toBe(true) + }) -vi.mock("../integrations/terminal/TerminalRegistry", () => ({ - TerminalRegistry: { - initialize: vi.fn(), - cleanup: vi.fn(), - }, -})) + it("skips for static providers", async () => { + const provider: any = { + getState: vi.fn().mockResolvedValue({ + apiConfiguration: { apiProvider: "anthropic", apiModelId: "claude-3-5-sonnet" }, + currentApiConfigName: "default", + }), + upsertProviderProfile: vi.fn(), + } -vi.mock("../services/mcp/McpServerManager", () => ({ - McpServerManager: { - cleanup: vi.fn().mockResolvedValue(undefined), - getInstance: vi.fn().mockResolvedValue(null), - unregisterProvider: vi.fn(), - }, -})) + await ensureResolvedModelInfo(provider) -vi.mock("../services/code-index/manager", () => ({ - CodeIndexManager: { - getInstance: vi.fn().mockReturnValue(null), - }, -})) + expect(buildApiHandler).not.toHaveBeenCalled() + expect(provider.upsertProviderProfile).not.toHaveBeenCalled() + }) +}) -vi.mock("../services/mdm/MdmService", () => ({ - MdmService: { - createInstance: vi.fn().mockResolvedValue(null), - }, -})) +describe("settings save gating (Phase 3.2)", () => { + let logSpy: ReturnType -vi.mock("../utils/migrateSettings", () => ({ - migrateSettings: vi.fn().mockResolvedValue(undefined), -})) + beforeEach(() => { + vi.clearAllMocks() + logSpy = vi.spyOn(console, "log").mockImplementation(() => {}) + }) -vi.mock("../utils/autoImportSettings", () => ({ - autoImportSettings: vi.fn().mockResolvedValue(undefined), -})) + afterEach(() => { + logSpy.mockRestore() + }) 
-vi.mock("../extension/api", () => ({ - API: vi.fn().mockImplementation(() => ({})), -})) + const bindProvider = (impl: any) => (ClineProvider.prototype.upsertProviderProfile as any).bind(impl) + + it("does not reinit on unrelated setting change and preserves resolvedModelInfo", async () => { + const prevConfig = { + apiProvider: "openrouter", + openRouterModelId: "openrouter/model", + openRouterBaseUrl: "https://openrouter.ai/api/v1", + resolvedModelInfo: { contextWindow: 4000, maxTokens: 8192 }, + modelTemperature: 0.1, + } + + const nextConfig = { + ...prevConfig, + modelTemperature: 0.2, // unrelated change + } + + const provider: any = { + providerSettingsManager: { + saveConfig: vi.fn().mockResolvedValue("id"), + listConfig: vi.fn().mockResolvedValue([]), + setModeConfig: vi.fn().mockResolvedValue(undefined), + }, + updateGlobalState: vi.fn().mockResolvedValue(undefined), + contextProxy: { setProviderSettings: vi.fn().mockResolvedValue(undefined) }, + getState: vi.fn().mockResolvedValue({ apiConfiguration: prevConfig, mode: "architect" }), + getCurrentTask: vi.fn().mockReturnValue({ api: undefined }), + postStateToWebview: vi.fn().mockResolvedValue(undefined), + log: vi.fn(), + } + + ;(buildApiHandler as any).mockReturnValue({}) // handler if reinit (should NOT be called) + + const upsert = bindProvider(provider) + await upsert("default", nextConfig, true) + + expect(provider.providerSettingsManager.saveConfig).toHaveBeenCalledWith("default", nextConfig) + expect(provider.contextProxy.setProviderSettings).toHaveBeenCalledWith(nextConfig) + expect(buildApiHandler).not.toHaveBeenCalled() + expect( + logSpy.mock.calls.some((c: any[]) => + String(c.join(" ")).includes("[model-cache/save] No reinit: provider/model/baseUrl unchanged"), + ), + ).toBe(true) + // Ensure resolvedModelInfo remained intact in persisted payload + expect((provider.providerSettingsManager.saveConfig as any).mock.calls[0][1].resolvedModelInfo).toEqual( + prevConfig.resolvedModelInfo, + ) + }) -vi.mock("../activate", () => ({ - handleUri: vi.fn(), - registerCommands: vi.fn(), - registerCodeActions: vi.fn(), - registerTerminalActions: vi.fn(), - CodeActionProvider: vi.fn().mockImplementation(() => ({ - providedCodeActionKinds: [], - })), -})) + it("reinit when provider/model/baseUrl changes (modelId change)", async () => { + const prevConfig = { + apiProvider: "openrouter", + openRouterModelId: "openrouter/model", + openRouterBaseUrl: "https://openrouter.ai/api/v1", + } -vi.mock("../i18n", () => ({ - initializeI18n: vi.fn(), - t: vi.fn((key) => key), -})) + const nextConfig = { + ...prevConfig, + openRouterModelId: "openrouter/other-model", // model change should trigger reinit + } -describe("extension.ts", () => { - let mockContext: vscode.ExtensionContext - let authStateChangedHandler: - | ((data: { state: AuthState; previousState: AuthState }) => void | Promise) - | undefined + const handler = {} + ;(buildApiHandler as any).mockReturnValue(handler) - beforeEach(() => { - vi.clearAllMocks() - mockBridgeOrchestratorDisconnect.mockClear() + const task: any = { api: undefined } - mockContext = { - extensionPath: "/test/path", - globalState: { - get: vi.fn().mockReturnValue(undefined), - update: vi.fn(), + const provider: any = { + providerSettingsManager: { + saveConfig: vi.fn().mockResolvedValue("id"), + listConfig: vi.fn().mockResolvedValue([]), + setModeConfig: vi.fn().mockResolvedValue(undefined), }, - subscriptions: [], - } as unknown as vscode.ExtensionContext - - authStateChangedHandler = undefined + 
updateGlobalState: vi.fn().mockResolvedValue(undefined), + contextProxy: { setProviderSettings: vi.fn().mockResolvedValue(undefined) }, + getState: vi.fn().mockResolvedValue({ apiConfiguration: prevConfig, mode: "architect" }), + getCurrentTask: vi.fn().mockReturnValue(task), + postStateToWebview: vi.fn().mockResolvedValue(undefined), + log: vi.fn(), + } + + const upsert = bindProvider(provider) + await upsert("default", nextConfig, true) + + expect(provider.providerSettingsManager.saveConfig).toHaveBeenCalledWith("default", nextConfig) + expect(buildApiHandler).toHaveBeenCalledWith(nextConfig) + expect(task.api).toBe(handler) + expect( + logSpy.mock.calls.some((c: any[]) => + String(c.join(" ")).includes("[model-cache/save] Reinit: relevant fields changed"), + ), + ).toBe(true) }) - test("authStateChangedHandler calls BridgeOrchestrator.disconnect when logged-out event fires", async () => { - const { CloudService, BridgeOrchestrator } = await import("@roo-code/cloud") - - // Capture the auth state changed handler. - vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { - if (handlers?.["auth-state-changed"]) { - authStateChangedHandler = handlers["auth-state-changed"] - } - - return { - off: vi.fn(), - on: vi.fn(), - telemetryClient: null, - } as any - }) - - // Activate the extension. - const { activate } = await import("../extension") - await activate(mockContext) - - // Verify handler was registered. - expect(authStateChangedHandler).toBeDefined() - - // Trigger logout. - await authStateChangedHandler!({ - state: "logged-out" as AuthState, - previousState: "logged-in" as AuthState, - }) - - // Verify BridgeOrchestrator.disconnect was called - expect(mockBridgeOrchestratorDisconnect).toHaveBeenCalled() - }) + it("reinit when router baseUrl changes", async () => { + const prevConfig = { + apiProvider: "requesty", + requestyModelId: "requesty/model", + requestyBaseUrl: "https://api.requesty.ai", + } + + const nextConfig = { + ...prevConfig, + requestyBaseUrl: "https://custom.requesty.ai", // baseUrl change should trigger reinit + } - test("authStateChangedHandler does not call BridgeOrchestrator.disconnect for other states", async () => { - const { CloudService } = await import("@roo-code/cloud") - - // Capture the auth state changed handler. - vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { - if (handlers?.["auth-state-changed"]) { - authStateChangedHandler = handlers["auth-state-changed"] - } - - return { - off: vi.fn(), - on: vi.fn(), - telemetryClient: null, - } as any - }) - - // Activate the extension. - const { activate } = await import("../extension") - await activate(mockContext) - - // Trigger login. - await authStateChangedHandler!({ - state: "logged-in" as AuthState, - previousState: "logged-out" as AuthState, - }) - - // Verify BridgeOrchestrator.disconnect was NOT called. 
- expect(mockBridgeOrchestratorDisconnect).not.toHaveBeenCalled() + const handler = {} + ;(buildApiHandler as any).mockReturnValue(handler) + + const task: any = { api: undefined } + + const provider: any = { + providerSettingsManager: { + saveConfig: vi.fn().mockResolvedValue("id"), + listConfig: vi.fn().mockResolvedValue([]), + setModeConfig: vi.fn().mockResolvedValue(undefined), + }, + updateGlobalState: vi.fn().mockResolvedValue(undefined), + contextProxy: { setProviderSettings: vi.fn().mockResolvedValue(undefined) }, + getState: vi.fn().mockResolvedValue({ apiConfiguration: prevConfig, mode: "architect" }), + getCurrentTask: vi.fn().mockReturnValue(task), + postStateToWebview: vi.fn().mockResolvedValue(undefined), + log: vi.fn(), + } + + const upsert = bindProvider(provider) + await upsert("default", nextConfig, true) + + expect(buildApiHandler).toHaveBeenCalledWith(nextConfig) + expect(task.api).toBe(handler) }) }) diff --git a/src/__tests__/settings-schema.resolvedModelInfo.spec.ts b/src/__tests__/settings-schema.resolvedModelInfo.spec.ts new file mode 100644 index 000000000000..9a977f2671ed --- /dev/null +++ b/src/__tests__/settings-schema.resolvedModelInfo.spec.ts @@ -0,0 +1,24 @@ +import { providerSettingsSchema, type ProviderSettings, type ModelInfo, PROVIDER_SETTINGS_KEYS } from "@roo-code/types" + +describe("ProviderSettings schema resolvedModelInfo", () => { + it("accepts and preserves resolvedModelInfo", () => { + const resolved: ModelInfo = { + contextWindow: 16384, + supportsPromptCache: true, + maxTokens: 8192, + } + + const input: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "openrouter/some-model", + resolvedModelInfo: resolved, + } + + const parsed = providerSettingsSchema.parse(input) + expect(parsed.resolvedModelInfo).toEqual(resolved) + }) + + it("includes resolvedModelInfo in PROVIDER_SETTINGS_KEYS", () => { + expect(PROVIDER_SETTINGS_KEYS).toContain("resolvedModelInfo") + }) +}) diff --git a/src/__tests__/webviewMessageHandler.refresh.spec.ts b/src/__tests__/webviewMessageHandler.refresh.spec.ts new file mode 100644 index 000000000000..85da3025cdf4 --- /dev/null +++ b/src/__tests__/webviewMessageHandler.refresh.spec.ts @@ -0,0 +1,155 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import type { ClineProvider } from "../core/webview/ClineProvider" +import type { ModelRecord } from "../shared/api" + +vi.mock("../api/providers/fetchers/modelCache", () => ({ + flushModels: vi.fn(), + getModels: vi.fn(), +})) + +vi.mock("../api/providers/fetchers/modelEndpointCache", () => ({ + flushModelProviders: vi.fn(), +})) + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + hasInstance: () => false, + }, +})) + +import { webviewMessageHandler } from "../core/webview/webviewMessageHandler" +import { flushModels, getModels } from "../api/providers/fetchers/modelCache" +import { flushModelProviders } from "../api/providers/fetchers/modelEndpointCache" + +const flushModelsMock = vi.mocked(flushModels) +const getModelsMock = vi.mocked(getModels) +const flushModelProvidersMock = vi.mocked(flushModelProviders) + +describe("webviewMessageHandler.flushRouterModels", () => { + let logSpy: ReturnType + let warnSpy: ReturnType + + beforeEach(() => { + vi.clearAllMocks() + logSpy = vi.spyOn(console, "log").mockImplementation(() => {}) + warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {}) + }) + + afterEach(() => { + logSpy.mockRestore() + warnSpy.mockRestore() + }) + + it("flushes caches, refetches models, persists 
resolvedModelInfo, and posts success", async () => { + const apiConfiguration = { + apiProvider: "openrouter", + openRouterModelId: "openrouter/model", + } + const getState = vi.fn().mockResolvedValue({ + apiConfiguration, + currentApiConfigName: "default", + }) + const postMessageToWebview = vi.fn() + const upsertProviderProfile = vi.fn().mockResolvedValue(undefined) + + const provider = { + getState, + postMessageToWebview, + upsertProviderProfile, + } as unknown as ClineProvider + + const models: ModelRecord = { + "openrouter/model": { + contextWindow: 32000, + maxTokens: 16000, + supportsImages: false, + supportsPromptCache: true, + }, + } + + getModelsMock.mockResolvedValue(models) + + await webviewMessageHandler(provider, { + type: "flushRouterModels", + } as any) + + expect(flushModelsMock).toHaveBeenCalledWith("openrouter") + expect(flushModelProvidersMock).toHaveBeenCalledWith("openrouter", "openrouter/model") + expect(getModelsMock).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(upsertProviderProfile).toHaveBeenCalledWith( + "default", + expect.objectContaining({ + resolvedModelInfo: models["openrouter/model"], + }), + true, + ) + expect(postMessageToWebview).toHaveBeenCalledWith({ type: "flushRouterModelsResult", success: true }) + }) + + it("supports router overrides supplied via message text when no provider model is selected", async () => { + const getState = vi.fn().mockResolvedValue({ + apiConfiguration: {}, + currentApiConfigName: undefined, + }) + const postMessageToWebview = vi.fn() + const upsertProviderProfile = vi.fn() + + const provider = { + getState, + postMessageToWebview, + upsertProviderProfile, + } as unknown as ClineProvider + + getModelsMock.mockResolvedValue({}) + + await webviewMessageHandler(provider, { + type: "flushRouterModels", + text: "requesty", + } as any) + + expect(flushModelsMock).toHaveBeenCalledWith("requesty") + expect(flushModelProvidersMock).not.toHaveBeenCalled() + expect(getModelsMock).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "requesty", + }), + ) + expect(upsertProviderProfile).not.toHaveBeenCalled() + expect(postMessageToWebview).toHaveBeenCalledWith({ type: "flushRouterModelsResult", success: true }) + }) + + it("posts failure response when refetching models throws", async () => { + const apiConfiguration = { + apiProvider: "openrouter", + openRouterModelId: "openrouter/model", + } + const getState = vi.fn().mockResolvedValue({ + apiConfiguration, + currentApiConfigName: "default", + }) + const postMessageToWebview = vi.fn() + const upsertProviderProfile = vi.fn().mockResolvedValue(undefined) + + const provider = { + getState, + postMessageToWebview, + upsertProviderProfile, + } as unknown as ClineProvider + + const failure = new Error("failed to refresh") + getModelsMock.mockRejectedValue(failure) + + await webviewMessageHandler(provider, { + type: "flushRouterModels", + } as any) + + expect(flushModelsMock).toHaveBeenCalledWith("openrouter") + expect(flushModelProvidersMock).toHaveBeenCalledWith("openrouter", "openrouter/model") + expect(upsertProviderProfile).not.toHaveBeenCalled() + expect(postMessageToWebview).toHaveBeenCalledWith({ + type: "flushRouterModelsResult", + success: false, + error: failure.message, + }) + }) +}) diff --git a/src/api/providers/__tests__/chutes.spec.ts b/src/api/providers/__tests__/chutes.spec.ts index b4c933d4cc57..2bd97533b167 100644 --- a/src/api/providers/__tests__/chutes.spec.ts +++ b/src/api/providers/__tests__/chutes.spec.ts @@ -249,11 +249,45 @@ 
describe("ChutesHandler", () => { apiModelId: testModelId, chutesApiKey: "test-chutes-api-key", }) - // Note: getModel() returns fallback default without calling fetchModel - // Since we haven't called fetchModel, it returns the default chutesDefaultModelId - // which is DeepSeek-R1-0528, therefore temperature will be DEEP_SEEK_DEFAULT_TEMPERATURE + // With new priority behavior, id stays as requested; non-DeepSeek defaults to 0.5 const model = handlerWithModel.getModel() - // The default model is DeepSeek-R1, so it returns DEEP_SEEK_DEFAULT_TEMPERATURE + expect(model.info.temperature).toBe(0.5) + }) +}) + +// Phase 2: getModel priority tests +describe("ChutesHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { maxTokens: 1024, contextWindow: 131072, supportsImages: false, supportsPromptCache: false } + const handler = new ChutesHandler({ + chutesApiKey: "k", + apiModelId: "deepseek-ai/DeepSeek-R1-0528", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("deepseek-ai/DeepSeek-R1-0528") + // Info includes resolved fields plus provider temperature decoration + expect(model.info).toEqual(expect.objectContaining(resolved)) expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE) }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new ChutesHandler({ chutesApiKey: "k", apiModelId: "unsloth/Llama-3.3-70B-Instruct" } as any) + ;(handler as any).models = { + "unsloth/Llama-3.3-70B-Instruct": { + maxTokens: 2048, + contextWindow: 262144, + supportsImages: false, + supportsPromptCache: false, + }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(2048) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new ChutesHandler({ chutesApiKey: "k", apiModelId: "unknown/model" } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) }) diff --git a/src/api/providers/__tests__/deepinfra.spec.ts b/src/api/providers/__tests__/deepinfra.spec.ts new file mode 100644 index 000000000000..e2f17cb8a477 --- /dev/null +++ b/src/api/providers/__tests__/deepinfra.spec.ts @@ -0,0 +1,62 @@ +// npx vitest run api/providers/__tests__/deepinfra.spec.ts + +import { describe, it, expect, vi, beforeEach } from "vitest" + +import { DeepInfraHandler } from "../deepinfra" +import type { ApiHandlerOptions } from "../../../shared/api" + +vi.mock("openai", () => ({ + default: class MockOpenAI { + baseURL: string + apiKey: string + chat = { completions: { create: vi.fn() } } + constructor(opts: any) { + this.baseURL = opts.baseURL + this.apiKey = opts.apiKey + } + }, +})) + +describe("DeepInfraHandler getModel priority", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { maxTokens: 1234, contextWindow: 56789, supportsImages: false, supportsPromptCache: true } + const handler = new DeepInfraHandler({ + deepInfraApiKey: "k", + deepInfraModelId: "meta/llama-3", + resolvedModelInfo: resolved, + } as any as ApiHandlerOptions) + const model = handler.getModel() + expect(model.id).toBe("meta/llama-3") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new DeepInfraHandler({ + deepInfraApiKey: "k", + deepInfraModelId: "openai/gpt-4o", + } as any as 
ApiHandlerOptions) + ;(handler as any).models = { + "openai/gpt-4o": { + maxTokens: 999, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: false, + }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(999) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new DeepInfraHandler({ + deepInfraApiKey: "k", + deepInfraModelId: "unknown/model", + } as any as ApiHandlerOptions) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) +}) diff --git a/src/api/providers/__tests__/glama.spec.ts b/src/api/providers/__tests__/glama.spec.ts index 9f82cad3ba47..0b67645e054d 100644 --- a/src/api/providers/__tests__/glama.spec.ts +++ b/src/api/providers/__tests__/glama.spec.ts @@ -225,8 +225,44 @@ describe("GlamaHandler", () => { it("should return default model when invalid model provided", async () => { const handlerWithInvalidModel = new GlamaHandler({ ...mockOptions, glamaModelId: "invalid/model" }) const modelInfo = await handlerWithInvalidModel.fetchModel() - expect(modelInfo.id).toBe("anthropic/claude-3-7-sonnet") + // Priority now preserves requested id with default info + expect(modelInfo.id).toBe("invalid/model") expect(modelInfo.info).toBeDefined() }) }) }) + +// Phase 2: getModel priority tests +describe("GlamaHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { maxTokens: 1111, contextWindow: 222222, supportsImages: true, supportsPromptCache: true } + const handler = new GlamaHandler({ + glamaApiKey: "k", + glamaModelId: "anthropic/claude-3-7-sonnet", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("anthropic/claude-3-7-sonnet") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new GlamaHandler({ glamaApiKey: "k", glamaModelId: "openai/gpt-4o" } as any) + ;(handler as any).models = { + "openai/gpt-4o": { + maxTokens: 3333, + contextWindow: 444444, + supportsImages: true, + supportsPromptCache: false, + }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(3333) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new GlamaHandler({ glamaApiKey: "k", glamaModelId: "unknown/model" } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) +}) diff --git a/src/api/providers/__tests__/io-intelligence.spec.ts b/src/api/providers/__tests__/io-intelligence.spec.ts index 3b46b79ee25f..91819ff597bf 100644 --- a/src/api/providers/__tests__/io-intelligence.spec.ts +++ b/src/api/providers/__tests__/io-intelligence.spec.ts @@ -274,6 +274,40 @@ describe("IOIntelligenceHandler", () => { }) }) + // Phase 2: getModel priority tests + describe("IOIntelligenceHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over provider models", () => { + const resolved = { + maxTokens: 5001, + contextWindow: 999999, + supportsImages: true, + supportsPromptCache: false, + } + const handler = new IOIntelligenceHandler({ + ioIntelligenceApiKey: "k", + ioIntelligenceModelId: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8") + 
expect(model.info).toBe(resolved) + }) + + it("uses provider models when no resolvedModelInfo", () => { + const handler = new IOIntelligenceHandler({ + ioIntelligenceApiKey: "k", + ioIntelligenceModelId: "openai/gpt-oss-120b", + } as any) + const model = handler.getModel() + expect(model.info).toEqual( + expect.objectContaining({ + contextWindow: expect.any(Number), + supportsPromptCache: expect.any(Boolean), + }), + ) + }) + }) + it("should use default model when no model is specified", () => { const handlerWithoutModel = new IOIntelligenceHandler({ ...mockOptions, diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts index fe62ad3922cf..9012cdc87437 100644 --- a/src/api/providers/__tests__/lite-llm.spec.ts +++ b/src/api/providers/__tests__/lite-llm.spec.ts @@ -327,8 +327,8 @@ describe("LiteLLMHandler", () => { } handler = new LiteLLMHandler(optionsWithGPT5) - // Force fetchModel to return undefined maxTokens - vi.spyOn(handler as any, "fetchModel").mockResolvedValue({ + // Force getModel to return undefined maxTokens + vi.spyOn(handler as any, "getModel").mockReturnValue({ id: "gpt-5", info: { ...litellmDefaultModelInfo, maxTokens: undefined }, }) @@ -370,8 +370,8 @@ describe("LiteLLMHandler", () => { } handler = new LiteLLMHandler(optionsWithGPT5) - // Force fetchModel to return undefined maxTokens - vi.spyOn(handler as any, "fetchModel").mockResolvedValue({ + // Force getModel to return undefined maxTokens + vi.spyOn(handler as any, "getModel").mockReturnValue({ id: "gpt-5", info: { ...litellmDefaultModelInfo, maxTokens: undefined }, }) @@ -387,4 +387,48 @@ describe("LiteLLMHandler", () => { expect(createCall.max_completion_tokens).toBeUndefined() }) }) + + // Phase 2: getModel priority tests + describe("LiteLLMHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { + maxTokens: 2468, + contextWindow: 135790, + supportsImages: false, + supportsPromptCache: true, + } + const handler = new LiteLLMHandler({ + litellmApiKey: "k", + litellmBaseUrl: "http://localhost:4000", + litellmModelId: "gpt-4", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("gpt-4") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new LiteLLMHandler({ + litellmApiKey: "k", + litellmBaseUrl: "http://localhost:4000", + litellmModelId: "llama-3", + } as any) + ;(handler as any).models = { + "llama-3": { maxTokens: 5000, contextWindow: 90000, supportsImages: false, supportsPromptCache: false }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(5000) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new LiteLLMHandler({ + litellmApiKey: "k", + litellmBaseUrl: "http://localhost:4000", + litellmModelId: "unknown/model", + } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) + }) }) diff --git a/src/api/providers/__tests__/openrouter.spec.ts b/src/api/providers/__tests__/openrouter.spec.ts index f5067ef34c96..c4c2ebe086c2 100644 --- a/src/api/providers/__tests__/openrouter.spec.ts +++ b/src/api/providers/__tests__/openrouter.spec.ts @@ -9,6 +9,8 @@ import OpenAI from "openai" import { OpenRouterHandler } from "../openrouter" import { ApiHandlerOptions } from "../../../shared/api" import { Package } from 
"../../../shared/package" +import { getModels } from "../fetchers/modelCache" +import { getModelEndpoints } from "../fetchers/modelEndpointCache" // Mock dependencies vitest.mock("openai") @@ -54,6 +56,9 @@ vitest.mock("../fetchers/modelCache", () => ({ }) }), })) +vitest.mock("../fetchers/modelEndpointCache", () => ({ + getModelEndpoints: vitest.fn().mockResolvedValue({}), +})) describe("OpenRouterHandler", () => { const mockOptions: ApiHandlerOptions = { @@ -78,6 +83,54 @@ describe("OpenRouterHandler", () => { }) }) + describe("getModel priority and caching", () => { + it("uses options.resolvedModelInfo when provided (persisted)", () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "anthropic/claude-sonnet-4", + resolvedModelInfo: { + maxTokens: 12345, + contextWindow: 99999, + supportsPromptCache: false, + } as any, + }) + + const logSpy = vitest.spyOn(console, "log").mockImplementation(() => {}) + const model = handler.getModel() + expect(model.id).toBe("anthropic/claude-sonnet-4") + expect(model.info.maxTokens).toBe(12345) + expect(logSpy).toHaveBeenCalledWith("[model-cache] source:", "persisted") + logSpy.mockRestore() + }) + + it("falls back to memory cache when persisted is absent", () => { + const handler = new OpenRouterHandler({ + openRouterApiKey: "test-key", + openRouterModelId: "custom/model", + }) + ;(handler as any).models = { + "custom/model": { maxTokens: 7777, contextWindow: 42424, supportsPromptCache: false }, + } + + const logSpy = vitest.spyOn(console, "log").mockImplementation(() => {}) + const model = handler.getModel() + expect(model.id).toBe("custom/model") + expect(model.info.maxTokens).toBe(7777) + expect(logSpy).toHaveBeenCalledWith("[model-cache] source:", "memory-cache") + logSpy.mockRestore() + }) + + it("falls back to openRouterDefaultModelInfo when both are absent", () => { + const handler = new OpenRouterHandler({}) + const logSpy = vitest.spyOn(console, "log").mockImplementation(() => {}) + const model = handler.getModel() + expect(model.id).toBe("anthropic/claude-sonnet-4") + expect(model.info.supportsPromptCache).toBe(true) + expect(logSpy).toHaveBeenCalledWith("[model-cache] source:", "default-fallback") + logSpy.mockRestore() + }) + }) + describe("fetchModel", () => { it("returns correct model info when options are provided", async () => { const handler = new OpenRouterHandler(mockOptions) diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index a8fabd40338a..4bdddac0ed79 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ -237,3 +237,32 @@ describe("RequestyHandler", () => { }) }) }) + +// Phase 2: getModel priority tests +describe("RequestyHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { maxTokens: 1234, contextWindow: 55555, supportsImages: false, supportsPromptCache: true } + const handler = new RequestyHandler({ + requestyModelId: "coding/claude-4-sonnet", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("coding/claude-4-sonnet") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new RequestyHandler({ requestyModelId: "router/model" } as any) + ;(handler as any).models = { + "router/model": { maxTokens: 2345, contextWindow: 64000, supportsImages: true, supportsPromptCache: false }, + 
} + const model = handler.getModel() + expect(model.info.maxTokens).toBe(2345) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new RequestyHandler({ requestyModelId: "unknown/model" } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) +}) diff --git a/src/api/providers/__tests__/roo.spec.ts b/src/api/providers/__tests__/roo.spec.ts index 7555a49d498e..6aef15775843 100644 --- a/src/api/providers/__tests__/roo.spec.ts +++ b/src/api/providers/__tests__/roo.spec.ts @@ -630,4 +630,49 @@ describe("RooHandler", () => { ) }) }) + + // Phase 2: getModel priority tests + describe("RooHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over shared cache and default", () => { + const resolved = { + maxTokens: 4096, + contextWindow: 131072, + supportsImages: false, + supportsReasoningEffort: true, + supportsPromptCache: true, + inputPrice: 0, + outputPrice: 0, + } + const handler = new RooHandler({ + apiModelId: "xai/grok-code-fast-1", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("xai/grok-code-fast-1") + expect(model.info).toBe(resolved) + }) + + it("uses shared cache when no resolvedModelInfo", async () => { + const handler = new RooHandler({ apiModelId: "xai/grok-code-fast-1" } as any) + const model = handler.getModel() + expect(model.info).toEqual( + expect.objectContaining({ + contextWindow: 262144, + supportsPromptCache: true, + }), + ) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new RooHandler({ apiModelId: "unknown/model" } as any) + const model = handler.getModel() + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 16384, + contextWindow: 262144, + supportsPromptCache: true, + }), + ) + }) + }) }) diff --git a/src/api/providers/__tests__/unbound.spec.ts b/src/api/providers/__tests__/unbound.spec.ts index fb52e2cb8cbb..5ce217f0024e 100644 --- a/src/api/providers/__tests__/unbound.spec.ts +++ b/src/api/providers/__tests__/unbound.spec.ts @@ -324,8 +324,44 @@ describe("UnboundHandler", () => { it("should return default model when invalid model provided", async () => { const handlerWithInvalidModel = new UnboundHandler({ ...mockOptions, unboundModelId: "invalid/model" }) const modelInfo = await handlerWithInvalidModel.fetchModel() - expect(modelInfo.id).toBe("anthropic/claude-sonnet-4-5") + // Priority now preserves requested id with default info + expect(modelInfo.id).toBe("invalid/model") expect(modelInfo.info).toBeDefined() }) }) }) + +// Phase 2: getModel priority tests +describe("UnboundHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { maxTokens: 7777, contextWindow: 888888, supportsImages: false, supportsPromptCache: true } + const handler = new UnboundHandler({ + unboundApiKey: "k", + unboundModelId: "anthropic/claude-3-5-sonnet-20241022", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("anthropic/claude-3-5-sonnet-20241022") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new UnboundHandler({ unboundApiKey: "k", unboundModelId: "openai/gpt-4o" } as any) + ;(handler as any).models = { + "openai/gpt-4o": { + maxTokens: 9999, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: false, 
+ }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(9999) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new UnboundHandler({ unboundApiKey: "k", unboundModelId: "unknown/model" } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) +}) diff --git a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts index f7c1d9d88518..865066db70aa 100644 --- a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts @@ -359,6 +359,52 @@ describe("VercelAiGatewayHandler", () => { }) }) + // Phase 2: getModel priority tests + describe("VercelAiGatewayHandler getModel priority", () => { + it("prefers options.resolvedModelInfo over cache and default", () => { + const resolved = { + maxTokens: 64000, + contextWindow: 200000, + supportsImages: true, + supportsPromptCache: true, + } + const handler = new VercelAiGatewayHandler({ + vercelAiGatewayApiKey: "k", + vercelAiGatewayModelId: "anthropic/claude-sonnet-4", + resolvedModelInfo: resolved, + } as any) + const model = handler.getModel() + expect(model.id).toBe("anthropic/claude-sonnet-4") + expect(model.info).toBe(resolved) + }) + + it("uses memory cache when no resolvedModelInfo", () => { + const handler = new VercelAiGatewayHandler({ + vercelAiGatewayApiKey: "k", + vercelAiGatewayModelId: "openai/gpt-4o", + } as any) + ;(handler as any).models = { + "openai/gpt-4o": { + maxTokens: 16000, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: false, + }, + } + const model = handler.getModel() + expect(model.info.maxTokens).toBe(16000) + }) + + it("falls back to default when neither persisted nor cache", () => { + const handler = new VercelAiGatewayHandler({ + vercelAiGatewayApiKey: "k", + vercelAiGatewayModelId: "unknown/model", + } as any) + const model = handler.getModel() + expect(model.info).toEqual(expect.objectContaining({ contextWindow: expect.any(Number) })) + }) + }) + describe("temperature support", () => { it("applies temperature for supported models", async () => { const handler = new VercelAiGatewayHandler({ diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts index d19c55abcec7..a497757c01fb 100644 --- a/src/api/providers/chutes.ts +++ b/src/api/providers/chutes.ts @@ -61,7 +61,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const model = await this.fetchModel() + const model = this.getModel() if (model.id.includes("DeepSeek-R1")) { const stream = await this.client.chat.completions.create({ @@ -127,7 +127,7 @@ export class ChutesHandler extends RouterProvider implements SingleCompletionHan } async completePrompt(prompt: string): Promise { - const model = await this.fetchModel() + const model = this.getModel() const { id: modelId, info } = model try { diff --git a/src/api/providers/deepinfra.ts b/src/api/providers/deepinfra.ts index fb8c117ae013..a522ed834c37 100644 --- a/src/api/providers/deepinfra.ts +++ b/src/api/providers/deepinfra.ts @@ -40,7 +40,11 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion override getModel() { const id = this.options.deepInfraModelId ?? deepInfraDefaultModelId - const info = this.models[id] ?? 
deepInfraDefaultModelInfo + const info = this.options.resolvedModelInfo ?? this.models[id] ?? deepInfraDefaultModelInfo + console.log( + "[model-cache] source:", + this.options.resolvedModelInfo ? "persisted" : this.models[id] ? "memory-cache" : "default-fallback", + ) const params = getModelParams({ format: "openai", @@ -57,9 +61,8 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion messages: Anthropic.Messages.MessageParam[], _metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - // Ensure we have up-to-date model metadata - await this.fetchModel() - const { id: modelId, info, reasoningEffort: reasoning_effort } = await this.fetchModel() + // Use current model metadata synchronously + const { id: modelId, info, reasoningEffort: reasoning_effort } = this.getModel() let prompt_cache_key = undefined if (info.supportsPromptCache && _metadata?.taskId) { prompt_cache_key = _metadata.taskId @@ -107,7 +110,6 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletion } async completePrompt(prompt: string): Promise { - await this.fetchModel() const { id: modelId, info } = this.getModel() const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 2a72ef1cc5f8..cc942484bd37 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -12,10 +12,25 @@ vi.mock("node-cache", () => { }) // Mock fs/promises to avoid file system operations -vi.mock("fs/promises", () => ({ - writeFile: vi.fn().mockResolvedValue(undefined), - readFile: vi.fn().mockResolvedValue("{}"), - mkdir: vi.fn().mockResolvedValue(undefined), +vi.mock("fs/promises", () => { + const mod = { + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue("{}"), + mkdir: vi.fn().mockResolvedValue(undefined), + // Default to "file exists"; individual tests will override readFile content as needed + access: vi.fn().mockResolvedValue(undefined), + unlink: vi.fn().mockResolvedValue(undefined), + rename: vi.fn().mockResolvedValue(undefined), + } + return { ...mod, default: mod } +}) + +// Provide stable paths for caches during tests +vi.mock("../../../../core/config/ContextProxy", () => ({ + ContextProxy: { instance: { globalStorageUri: { fsPath: "/tmp" } } }, +})) +vi.mock("../../../../utils/storage", () => ({ + getCacheDirectoryPath: vi.fn().mockResolvedValue("/tmp/cache"), })) // Mock all the model fetchers @@ -28,7 +43,8 @@ vi.mock("../io-intelligence") // Then imports import type { Mock } from "vitest" -import { getModels } from "../modelCache" +import { getModels, flushModels } from "../modelCache" +import { flushModelProviders } from "../modelEndpointCache" import { getLiteLLMModels } from "../litellm" import { getOpenRouterModels } from "../openrouter" import { getRequestyModels } from "../requesty" @@ -48,8 +64,22 @@ const DUMMY_UNBOUND_KEY = "unbound-key-for-testing" const DUMMY_IOINTELLIGENCE_KEY = "io-intelligence-key-for-testing" describe("getModels with new GetModelsOptions", () => { - beforeEach(() => { - vi.clearAllMocks() + beforeEach(async () => { + vi.resetAllMocks() + + // Re-prime mocked storage/helper modules after resetAllMocks clears implementations + const storage = await import("../../../../utils/storage") + ;(storage.getCacheDirectoryPath as unknown as Mock).mockResolvedValue("/tmp/cache") + + 
const ctx = await import("../../../../core/config/ContextProxy") + ;(ctx as any).ContextProxy = { instance: { globalStorageUri: { fsPath: "/tmp" } } } + + // Ensure memory cache does not leak across tests + await Promise.all( + ["litellm", "openrouter", "requesty", "glama", "unbound", "io-intelligence"].map((r) => + flushModels(r as any), + ), + ) }) it("calls getLiteLLMModels with correct parameters", async () => { @@ -63,6 +93,9 @@ describe("getModels with new GetModelsOptions", () => { } mockGetLiteLLMModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "litellm", apiKey: "test-api-key", @@ -84,6 +117,9 @@ describe("getModels with new GetModelsOptions", () => { } mockGetOpenRouterModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "openrouter" }) expect(mockGetOpenRouterModels).toHaveBeenCalled() @@ -101,6 +137,9 @@ describe("getModels with new GetModelsOptions", () => { } mockGetRequestyModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "requesty", apiKey: DUMMY_REQUESTY_KEY }) expect(mockGetRequestyModels).toHaveBeenCalledWith(undefined, DUMMY_REQUESTY_KEY) @@ -118,6 +157,9 @@ describe("getModels with new GetModelsOptions", () => { } mockGetGlamaModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "glama" }) expect(mockGetGlamaModels).toHaveBeenCalled() @@ -135,6 +177,9 @@ describe("getModels with new GetModelsOptions", () => { } mockGetUnboundModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "unbound", apiKey: DUMMY_UNBOUND_KEY }) expect(mockGetUnboundModels).toHaveBeenCalledWith(DUMMY_UNBOUND_KEY) @@ -152,13 +197,111 @@ describe("getModels with new GetModelsOptions", () => { } mockGetIOIntelligenceModels.mockResolvedValue(mockModels) + const fsp = await import("fs/promises") + ;(fsp.readFile as unknown as Mock).mockResolvedValueOnce(JSON.stringify(mockModels)) + const result = await getModels({ provider: "io-intelligence", apiKey: DUMMY_IOINTELLIGENCE_KEY }) expect(mockGetIOIntelligenceModels).toHaveBeenCalled() expect(result).toEqual(mockModels) }) + describe("explicit flush and no auto-expiration", () => { + it("flushModels clears memory and attempts to delete file cache", async () => { + const fsUtils = await import("../../../../utils/fs") + const existsSpy = vi.spyOn(fsUtils, "fileExistsAtPath").mockResolvedValue(true) + + const fsp = await import("fs/promises") + const def = (fsp as any).default ?? 
(fsp as any) + const unlink = def.unlink as unknown as Mock + unlink.mockClear() + + // Act + await flushModels("openrouter") + + // Assert file deletion attempted with expected filename pattern + expect(unlink).toHaveBeenCalled() + const [[calledPath]] = (unlink as unknown as { mock: { calls: [string][] } }).mock.calls + expect(String(calledPath)).toContain("openrouter_models.json") + + existsSpy.mockRestore() + }) + + it("flushModelProviders clears memory and attempts to delete endpoints file cache", async () => { + const fsUtils = await import("../../../../utils/fs") + const existsSpy = vi.spyOn(fsUtils, "fileExistsAtPath").mockResolvedValue(true) + + const fsp = await import("fs/promises") + const def = (fsp as any).default ?? (fsp as any) + const unlink = def.unlink as unknown as Mock + unlink.mockClear() + + await flushModelProviders("openrouter", "test-model") + + // Assert endpoints file deletion attempted with expected filename pattern + expect(unlink).toHaveBeenCalled() + const calls = (unlink as any).mock.calls.map((c: any[]) => String(c[0])) + expect(calls.some((p: string) => p.includes("openrouter_test-model_endpoints.json"))).toBe(true) + + existsSpy.mockRestore() + }) + + it("does not auto-expire cached entries after previous TTL window", async () => { + vi.useFakeTimers() + vi.resetModules() + + // Use real NodeCache for this re-import + vi.unmock("node-cache") + + const expectedModels = { + "test/model": { + maxTokens: 1024, + contextWindow: 8192, + supportsPromptCache: false, + }, + } + + // Lightweight mocks to avoid real FS and VSCode context + vi.doMock("fs/promises", () => ({ + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue(JSON.stringify(expectedModels)), + mkdir: vi.fn().mockResolvedValue(undefined), + access: vi.fn().mockResolvedValue(undefined), + unlink: vi.fn().mockResolvedValue(undefined), + rename: vi.fn().mockResolvedValue(undefined), + })) + vi.doMock("../../../../utils/safeWriteJson", () => ({ + safeWriteJson: vi.fn().mockResolvedValue(undefined), + })) + vi.doMock("../../../../core/config/ContextProxy", () => ({ + ContextProxy: { instance: { globalStorageUri: { fsPath: "/tmp" } } }, + })) + vi.doMock("../../../../utils/storage", () => ({ + getCacheDirectoryPath: vi.fn().mockResolvedValue("/tmp/cache"), + })) + vi.doMock("../openrouter", () => ({ + getOpenRouterModels: vi.fn().mockResolvedValue(expectedModels), + })) + + const { getModels, getModelsFromCache } = await import("../modelCache") + + await getModels({ provider: "openrouter" }) + expect(getModelsFromCache("openrouter")).toEqual(expectedModels) + + // Advance beyond the old TTL (5 minutes) + vi.advanceTimersByTime(6 * 60 * 1000) + + // Value should still be present (no auto-expiry) + expect(getModelsFromCache("openrouter")).toEqual(expectedModels) + + vi.useRealTimers() + }) + }) + it("handles errors and re-throws them", async () => { + // Ensure no leftover implementation from previous tests + mockGetLiteLLMModels.mockReset() + const expectedError = new Error("LiteLLM connection failed") mockGetLiteLLMModels.mockRejectedValue(expectedError) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 722e66dd7286..8a9841abe1a4 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -27,7 +27,7 @@ import { getHuggingFaceModels } from "./huggingface" import { getRooModels } from "./roo" import { getChutesModels } from "./chutes" -const memoryCache = new NodeCache({ 
stdTTL: 5 * 60, checkperiod: 5 * 60 }) +const memoryCache = new NodeCache({ stdTTL: 0, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { const filename = `${router}_models.json` @@ -145,7 +145,19 @@ export const getModels = async (options: GetModelsOptions): Promise * @param router - The router to flush models for. */ export const flushModels = async (router: RouterName) => { + // Clear in-memory cache memoryCache.del(router) + + // Best-effort delete of persisted file cache (ignore ENOENT) + try { + const filename = `${router}_models.json` + const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) + const filePath = path.join(cacheDir, filename) + + await fs.unlink(filePath).catch(() => {}) + } catch (err) { + console.error(`[flushModels] failed to delete persisted cache for ${router}:`, err) + } } export function getModelsFromCache(provider: ProviderName) { diff --git a/src/api/providers/fetchers/modelEndpointCache.ts b/src/api/providers/fetchers/modelEndpointCache.ts index 256ae8404800..41aed343f26e 100644 --- a/src/api/providers/fetchers/modelEndpointCache.ts +++ b/src/api/providers/fetchers/modelEndpointCache.ts @@ -12,7 +12,7 @@ import { fileExistsAtPath } from "../../../utils/fs" import { getOpenRouterModelEndpoints } from "./openrouter" -const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) +const memoryCache = new NodeCache({ stdTTL: 0, checkperiod: 5 * 60 }) const getCacheKey = (router: RouterName, modelId: string) => sanitize(`${router}_${modelId}`) @@ -79,5 +79,19 @@ export const getModelEndpoints = async ({ return modelProviders ?? {} } -export const flushModelProviders = async (router: RouterName, modelId: string) => - memoryCache.del(getCacheKey(router, modelId)) +export const flushModelProviders = async (router: RouterName, modelId: string) => { + // Clear in-memory cache for this (router, modelId) key + const key = getCacheKey(router, modelId) + memoryCache.del(key) + + // Best-effort delete of persisted file cache (ignore ENOENT) + try { + const filename = `${key}_endpoints.json` + const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath) + const filePath = path.join(cacheDir, filename) + + await fs.unlink(filePath).catch(() => {}) + } catch (err) { + console.error(`[flushModelProviders] failed to delete persisted endpoints cache for ${key}:`, err) + } +} diff --git a/src/api/providers/fetchers/openrouter.ts b/src/api/providers/fetchers/openrouter.ts index b546c40a3cfc..9b85d0da51ac 100644 --- a/src/api/providers/fetchers/openrouter.ts +++ b/src/api/providers/fetchers/openrouter.ts @@ -115,6 +115,9 @@ export async function getOpenRouterModels(options?: ApiHandlerOptions): Promise< continue } + console.log( + `[openrouter] fetched model ${id}: context_length=${model.context_length}, max_completion_tokens=${top_provider?.max_completion_tokens ?? "n/a"}`, + ) models[id] = parseOpenRouterModel({ id, model, @@ -161,6 +164,9 @@ export async function getOpenRouterModelEndpoints( } for (const endpoint of endpoints) { + console.log( + `[openrouter] fetched model ${id} endpoint ${endpoint.tag ?? endpoint.provider_name}: context_length=${endpoint.context_length}, max_completion_tokens=${endpoint.max_completion_tokens ?? "n/a"}`, + ) models[endpoint.tag ?? 
endpoint.provider_name] = parseOpenRouterModel({ id, model: endpoint, diff --git a/src/api/providers/glama.ts b/src/api/providers/glama.ts index 774d61570973..8b5a4f68642f 100644 --- a/src/api/providers/glama.ts +++ b/src/api/providers/glama.ts @@ -38,7 +38,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHand messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, @@ -117,7 +117,7 @@ export class GlamaHandler extends RouterProvider implements SingleCompletionHand } async completePrompt(prompt: string): Promise { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() try { const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { diff --git a/src/api/providers/io-intelligence.ts b/src/api/providers/io-intelligence.ts index ef1c60a6a2c7..4a8851057550 100644 --- a/src/api/providers/io-intelligence.ts +++ b/src/api/providers/io-intelligence.ts @@ -23,14 +23,25 @@ export class IOIntelligenceHandler extends BaseOpenAiCompatibleProvider { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() // Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens const isGPT5Model = this.isGpt5(modelId) diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 580b17331194..3f799b72d79a 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -102,7 +102,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH systemPrompt: string, messages: Anthropic.Messages.MessageParam[], ): AsyncGenerator { - const model = await this.fetchModel() + const model = this.getModel() let { id: modelId, maxTokens, temperature, topP, reasoning } = model @@ -225,7 +225,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH override getModel() { const id = this.options.openRouterModelId ?? openRouterDefaultModelId - let info = this.models[id] ?? openRouterDefaultModelInfo + + // Priority: 1) persisted (resolvedModelInfo), 2) memory cache, 3) default fallback + console.log( + "[model-cache] source:", + this.options.resolvedModelInfo ? "persisted" : this.models[id] ? "memory-cache" : "default-fallback", + ) + + let info = + (this.options.resolvedModelInfo as any) ?? (this.models[id] as any) ?? (openRouterDefaultModelInfo as any) // If a specific provider is requested, use the endpoint for that provider. 
if (this.options.openRouterSpecificProvider && this.endpoints[this.options.openRouterSpecificProvider]) { @@ -246,7 +254,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } async completePrompt(prompt: string) { - let { id: modelId, maxTokens, temperature, reasoning } = await this.fetchModel() + let { id: modelId, maxTokens, temperature, reasoning } = this.getModel() const completionParams: OpenRouterChatCompletionParams = { model: modelId, diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index 1c0e9ed64075..970169aaf68e 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -67,7 +67,11 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan override getModel() { const id = this.options.requestyModelId ?? requestyDefaultModelId - const info = this.models[id] ?? requestyDefaultModelInfo + const info = this.options.resolvedModelInfo ?? this.models[id] ?? requestyDefaultModelInfo + console.log( + "[model-cache] source:", + this.options.resolvedModelInfo ? "persisted" : this.models[id] ? "memory-cache" : "default-fallback", + ) const params = getModelParams({ format: "anthropic", @@ -111,7 +115,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan temperature, reasoningEffort: reasoning_effort, reasoning: thinking, - } = await this.fetchModel() + } = this.getModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, @@ -160,7 +164,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan } async completePrompt(prompt: string): Promise { - const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() + const { id: model, maxTokens: max_tokens, temperature } = this.getModel() let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts index 327796a1ffca..aa8939c08c98 100644 --- a/src/api/providers/roo.ts +++ b/src/api/providers/roo.ts @@ -194,15 +194,23 @@ export class RooHandler extends BaseOpenAiCompatibleProvider { override getModel() { const modelId = this.options.apiModelId || rooDefaultModelId - // Get models from shared cache + // 1) Persisted + if (this.options.resolvedModelInfo) { + console.log("[model-cache] source:", "persisted") + return { id: modelId, info: this.options.resolvedModelInfo } + } + + // 2) Shared cache const models = getModelsFromCache("roo") || {} const modelInfo = models[modelId] + console.log("[model-cache] source:", modelInfo ? "memory-cache" : "default-fallback") + if (modelInfo) { return { id: modelId, info: modelInfo } } - // Return the requested model ID even if not found, with fallback info. + // 3) Fallback defaults return { id: modelId, info: { diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts index 25e9a11e1b2c..83f2533fff2d 100644 --- a/src/api/providers/router-provider.ts +++ b/src/api/providers/router-provider.ts @@ -62,10 +62,12 @@ export abstract class RouterProvider extends BaseProvider { override getModel(): { id: string; info: ModelInfo } { const id = this.modelId ?? this.defaultModelId - - return this.models[id] - ? { id, info: this.models[id] } - : { id: this.defaultModelId, info: this.defaultModelInfo } + const info = this.options.resolvedModelInfo ?? this.models[id] ?? 
this.defaultModelInfo + console.log( + "[model-cache] source:", + this.options.resolvedModelInfo ? "persisted" : this.models[id] ? "memory-cache" : "default-fallback", + ) + return { id, info } } protected supportsTemperature(modelId: string): boolean { diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index bc85dfd499f1..84db68c22f0b 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -57,7 +57,7 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, @@ -133,7 +133,7 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa } async completePrompt(prompt: string): Promise { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() try { const requestOptions: UnboundChatCompletionCreateParamsNonStreaming = { diff --git a/src/api/providers/vercel-ai-gateway.ts b/src/api/providers/vercel-ai-gateway.ts index be77d35986b4..df40f918c736 100644 --- a/src/api/providers/vercel-ai-gateway.ts +++ b/src/api/providers/vercel-ai-gateway.ts @@ -41,7 +41,7 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ { role: "system", content: systemPrompt }, @@ -88,7 +88,7 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp } async completePrompt(prompt: string): Promise { - const { id: modelId, info } = await this.fetchModel() + const { id: modelId, info } = this.getModel() try { const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { diff --git a/src/core/sliding-window/index.ts b/src/core/sliding-window/index.ts index 1e518c9a56d9..f1ce53e92ab0 100644 --- a/src/core/sliding-window/index.ts +++ b/src/core/sliding-window/index.ts @@ -142,8 +142,17 @@ export async function truncateConversationIfNeeded({ } // If no specific threshold is found for the profile, fall back to global setting + // Debug: log context window and thresholds for sliding-window checks + console.log( + `[sliding-window] check: contextWindow=${contextWindow}, prevContextTokens=${prevContextTokens}, reservedTokens=${reservedTokens}, allowedTokens=${allowedTokens}, effectiveThreshold=${effectiveThreshold}, autoCondenseContext=${autoCondenseContext}`, + ) + if (autoCondenseContext) { const contextPercent = (100 * prevContextTokens) / contextWindow + // Debug: log auto-condense threshold check with current context window + console.log( + `[sliding-window] auto-condense check: contextWindow=${contextWindow}, contextPercent=${contextPercent.toFixed(2)}%, threshold=${effectiveThreshold}%, allowedTokens=${allowedTokens}`, + ) if (contextPercent >= effectiveThreshold || prevContextTokens > allowedTokens) { // Attempt to intelligently condense the context const result = await summarizeConversation( @@ -166,6 +175,10 @@ export async function truncateConversationIfNeeded({ } // Fall back to sliding window truncation if needed + // Debug: log fallback 
sliding-window check with current context window + console.log( + `[sliding-window] fallback check: contextWindow=${contextWindow}, prevContextTokens=${prevContextTokens}, allowedTokens=${allowedTokens}`, + ) if (prevContextTokens > allowedTokens) { const truncatedMessages = truncateConversation(messages, 0.5, taskId) return { messages: truncatedMessages, prevContextTokens, summary: "", cost, error } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index f73a36d445ee..2fc4ebb7fb95 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -42,6 +42,7 @@ import { ORGANIZATION_ALLOW_ALL, DEFAULT_MODES, DEFAULT_CHECKPOINT_TIMEOUT_SECONDS, + modelIdKeysByProvider, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" import { CloudService, BridgeOrchestrator, getRooCodeApiUrl } from "@roo-code/cloud" @@ -1301,26 +1302,58 @@ export class ClineProvider activate: boolean = true, ): Promise { try { - // TODO: Do we need to be calling `activateProfile`? It's not - // clear to me what the source of truth should be; in some cases - // we rely on the `ContextProxy`'s data store and in other cases - // we rely on the `ProviderSettingsManager`'s data store. It might - // be simpler to unify these two. + // Read previous state for change detection + const prevState = await this.getState() + const prev = (prevState?.apiConfiguration ?? {}) as ProviderSettings + const next = providerSettings ?? ({} as ProviderSettings) + + // Determine relevant keys for change detection + const providerChanged = (prev.apiProvider || undefined) !== (next.apiProvider || undefined) + + const providerName = next.apiProvider as ProviderName | undefined + const modelKey = providerName + ? modelIdKeysByProvider[providerName as keyof typeof modelIdKeysByProvider] + : undefined + + const normalize = (v: unknown) => { + if (v === null || v === undefined) return undefined + const s = String(v).trim() + return s.length ? s : undefined + } + + const modelChanged = modelKey + ? normalize((prev as any)[modelKey]) !== normalize((next as any)[modelKey]) + : false + + // Base URL keys for router-compatible providers + const baseUrlKey: keyof ProviderSettings | undefined = (() => { + switch (providerName) { + case "openrouter": + return "openRouterBaseUrl" + case "requesty": + return "requestyBaseUrl" + case "litellm": + return "litellmBaseUrl" + case "deepinfra": + return "deepInfraBaseUrl" + default: + return undefined + } + })() + + const baseUrlChanged = baseUrlKey + ? normalize((prev as any)[baseUrlKey]) !== normalize((next as any)[baseUrlKey]) + : false + + const shouldReinit = providerChanged || modelChanged || baseUrlChanged + + // Persist configuration first const id = await this.providerSettingsManager.saveConfig(name, providerSettings) if (activate) { - const { mode } = await this.getState() - - // These promises do the following: - // 1. Adds or updates the list of provider profiles. - // 2. Sets the current provider profile. - // 3. Sets the current mode's provider profile. - // 4. Copies the provider settings to the context. - // - // Note: 1, 2, and 4 can be done in one `ContextProxy` call: - // this.contextProxy.setValues({ ...providerSettings, listApiConfigMeta: ..., currentApiConfigName: ... }) - // We should probably switch to that and verify that it works. - // I left the original implementation in just to be safe. 
+ const { mode } = prevState + + // Keep state in sync regardless of reinit await Promise.all([ this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig()), this.updateGlobalState("currentApiConfigName", name), @@ -1328,12 +1361,16 @@ export class ClineProvider this.contextProxy.setProviderSettings(providerSettings), ]) - // Change the provider for the current task. - // TODO: We should rename `buildApiHandler` for clarity (e.g. `getProviderClient`). - const task = this.getCurrentTask() - - if (task) { - task.api = buildApiHandler(providerSettings) + // Only rebuild API handler if relevant fields changed + if (shouldReinit) { + console.log("[model-cache/save] Reinit: relevant fields changed") + const task = this.getCurrentTask() + if (task) { + // Lightweight re-init (no forced fetch) + task.api = buildApiHandler(providerSettings) + } + } else { + console.log("[model-cache/save] No reinit: provider/model/baseUrl unchanged") } } else { await this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig()) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c06729674503..9f5e1d917bd8 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -52,6 +52,7 @@ import { openMention } from "../mentions" import { getWorkspacePath } from "../../utils/path" import { Mode, defaultModeSlug } from "../../shared/modes" import { getModels, flushModels } from "../../api/providers/fetchers/modelCache" +import { flushModelProviders } from "../../api/providers/fetchers/modelEndpointCache" import { GetModelsOptions } from "../../shared/api" import { generateSystemPrompt } from "./generateSystemPrompt" import { getCommand } from "../../utils/commands" @@ -750,10 +751,111 @@ export const webviewMessageHandler = async ( case "resetState": await provider.resetState() break - case "flushRouterModels": - const routerNameFlush: RouterName = toRouterName(message.text) - await flushModels(routerNameFlush) + case "flushRouterModels": { + try { + const { apiConfiguration, currentApiConfigName = "default" } = await provider.getState() + const providerName = apiConfiguration?.apiProvider + const router: RouterName = providerName ? toRouterName(providerName) : toRouterName(message.text) + + // Determine selected modelId from provider profile + let selectedModelId: string | undefined + try { + const { modelIdKeysByProvider } = await import("@roo-code/types") + const key = providerName ? (modelIdKeysByProvider as any)[providerName] : undefined + selectedModelId = key ? 
(apiConfiguration as any)[key] : (apiConfiguration as any)?.apiModelId + } catch { + selectedModelId = (apiConfiguration as any)?.apiModelId + } + + // Flush caches (memory + file) + await flushModels(router) + if (selectedModelId) { + await flushModelProviders(router, selectedModelId) + } + console.log("[model-cache/refresh] Flushed memory+file cache for", router) + + // Build options for refetch + const buildOptions = (): GetModelsOptions => { + switch (router) { + case "requesty": + return { + provider: "requesty", + apiKey: (apiConfiguration as any).requestyApiKey, + baseUrl: (apiConfiguration as any).requestyBaseUrl, + } + case "glama": + return { provider: "glama" } + case "unbound": + return { provider: "unbound", apiKey: (apiConfiguration as any).unboundApiKey } + case "litellm": + return { + provider: "litellm", + apiKey: (apiConfiguration as any).litellmApiKey, + baseUrl: (apiConfiguration as any).litellmBaseUrl, + } + case "deepinfra": + return { + provider: "deepinfra", + apiKey: (apiConfiguration as any).deepInfraApiKey, + baseUrl: (apiConfiguration as any).deepInfraBaseUrl, + } + case "io-intelligence": + return { + provider: "io-intelligence", + apiKey: (apiConfiguration as any).ioIntelligenceApiKey, + } + case "vercel-ai-gateway": + return { provider: "vercel-ai-gateway" } + case "openrouter": + return { provider: "openrouter" } + case "roo": + return { + provider: "roo", + baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", + apiKey: CloudService.hasInstance() + ? CloudService.instance.authService?.getSessionToken() + : undefined, + } + case "chutes": + return { provider: "chutes", apiKey: (apiConfiguration as any).chutesApiKey } + case "ollama": + return { + provider: "ollama", + baseUrl: (apiConfiguration as any).ollamaBaseUrl, + apiKey: (apiConfiguration as any).ollamaApiKey, + } + case "lmstudio": + return { provider: "lmstudio", baseUrl: (apiConfiguration as any).lmStudioBaseUrl } + case "huggingface": + return { provider: "huggingface" } + default: + return { provider: router } + } + } + + // Refetch fresh models to warm caches + const options = buildOptions() + const models = await getModels(options) + + // Persist resolvedModelInfo for selected model if available + if (selectedModelId && models && models[selectedModelId]) { + const info = models[selectedModelId] as any + const updatedConfig = { ...apiConfiguration, resolvedModelInfo: info } + await provider.upsertProviderProfile(currentApiConfigName || "default", updatedConfig, true) + console.log("[model-cache/refresh] Persisted resolvedModelInfo for", router, selectedModelId) + } + + await provider.postMessageToWebview({ type: "flushRouterModelsResult", success: true }) + } catch (error) { + console.warn("[model-cache/refresh] Refresh failed:", error) + await provider.postMessageToWebview({ + type: "flushRouterModelsResult", + success: false, + error: error instanceof Error ? 
error.message : String(error), + }) + } break + } case "requestRouterModels": const { apiConfiguration } = await provider.getState() diff --git a/src/extension.ts b/src/extension.ts index bf0ceec02c29..911da47abe26 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -12,7 +12,8 @@ try { console.warn("Failed to load environment variables:", e) } -import type { CloudUserInfo, AuthState } from "@roo-code/types" +import type { CloudUserInfo, AuthState, ModelInfo } from "@roo-code/types" +import { isDynamicProvider } from "@roo-code/types" import { CloudService, BridgeOrchestrator } from "@roo-code/cloud" import { TelemetryService, PostHogTelemetryClient } from "@roo-code/telemetry" @@ -31,6 +32,7 @@ import { MdmService } from "./services/mdm/MdmService" import { migrateSettings } from "./utils/migrateSettings" import { autoImportSettings } from "./utils/autoImportSettings" import { API } from "./extension/api" +import { buildApiHandler } from "./api" import { handleUri, @@ -42,6 +44,60 @@ import { import { initializeI18n } from "./i18n" import { flushModels, getModels } from "./api/providers/fetchers/modelCache" +/** + * Phase 3: activation-time self-healing population of resolvedModelInfo + * Only activation wiring and persistence via existing API profile path. + */ +export async function ensureResolvedModelInfo(provider: ClineProvider): Promise { + try { + const state = await provider.getState() + const apiConfiguration = state.apiConfiguration + const providerName = apiConfiguration?.apiProvider + + // Process only dynamic providers + if (!providerName || !isDynamicProvider(providerName)) { + return + } + + // If resolvedModelInfo exists and is valid, skip + const existing = apiConfiguration?.resolvedModelInfo as ModelInfo | undefined + if (existing && typeof existing.contextWindow === "number" && typeof (existing as any).maxTokens === "number") { + console.log("[model-cache] Using existing resolvedModelInfo for", providerName) + return + } + + console.log("[model-cache] Populating resolvedModelInfo for", providerName) + + // Build handler and resolve model info (prefer fetchModel() if available) + const handler = buildApiHandler(apiConfiguration) + let info: ModelInfo | undefined + + const maybeFetch = (handler as any)?.fetchModel + if (typeof maybeFetch === "function") { + const fetched = await maybeFetch.call(handler) + if (fetched && typeof fetched === "object") { + if ("info" in fetched && (fetched as any).info) { + info = (fetched as any).info as ModelInfo + } else if ("contextWindow" in fetched) { + info = fetched as ModelInfo + } + } + } + + if (!info) { + info = handler.getModel().info + } + + if (info) { + const profileName = state.currentApiConfigName || "default" + const updatedConfig = { ...apiConfiguration, resolvedModelInfo: info } + // Persist via same path as settings saves + await provider.upsertProviderProfile(profileName, updatedConfig, true) + } + } catch (error) { + console.warn("[model-cache] Failed to populate resolvedModelInfo:", error) + } +} /** * Built using https://github.com/microsoft/vscode-webview-ui-toolkit * @@ -254,6 +310,9 @@ export async function activate(context: vscode.ExtensionContext) { ) } + // Activation-time self-healing for resolvedModelInfo (non-blocking) + void ensureResolvedModelInfo(provider) + registerCommands({ context, outputChannel, provider }) /** diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 7d2759c91905..eec2dd2b0811 100644 --- a/src/shared/ExtensionMessage.ts +++ 
b/src/shared/ExtensionMessage.ts @@ -112,6 +112,7 @@ export interface ExtensionMessage { | "authenticatedUser" | "condenseTaskContextResponse" | "singleRouterModelFetchResponse" + | "flushRouterModelsResult" | "indexingStatusUpdate" | "indexCleared" | "codebaseIndexConfig" diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index e2e7ba561573..11b4ade3a95f 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -232,12 +232,6 @@ const ApiOptions = ({ vscode.postMessage({ type: "requestLmStudioModels" }) } else if (selectedProvider === "vscode-lm") { vscode.postMessage({ type: "requestVsCodeLmModels" }) - } else if ( - selectedProvider === "litellm" || - selectedProvider === "deepinfra" || - selectedProvider === "roo" - ) { - vscode.postMessage({ type: "requestRouterModels" }) } }, 250, diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 6020a260bd36..eadec984b2d4 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -25,6 +25,8 @@ import { useEscapeKey } from "@src/hooks/useEscapeKey" import { ModelInfoView } from "./ModelInfoView" import { ApiErrorMessage } from "./ApiErrorMessage" +import { vscode } from "@src/utils/vscode" +import type { ExtensionMessage } from "@roo/ExtensionMessage" type ModelIdKey = keyof Pick< ProviderSettings, @@ -71,6 +73,8 @@ export const ModelPicker = ({ const [open, setOpen] = useState(false) const [isDescriptionExpanded, setIsDescriptionExpanded] = useState(false) + const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") + const [refreshError, setRefreshError] = useState() const isInitialized = useRef(false) const searchInputRef = useRef(null) const selectTimeoutRef = useRef(null) @@ -112,6 +116,12 @@ export const ModelPicker = ({ setOpen(false) setApiConfigurationField(modelIdKey, modelId) + // Persist resolvedModelInfo immediately if available from cached models + if (models && models[modelId]) { + setApiConfigurationField("resolvedModelInfo", models[modelId], false) + console.log("[model-cache/ui] persisted resolvedModelInfo from cached models") + } + // Clear any existing timeout if (selectTimeoutRef.current) { clearTimeout(selectTimeoutRef.current) @@ -120,7 +130,7 @@ export const ModelPicker = ({ // Delay to ensure the popover is closed before setting the search value. 
selectTimeoutRef.current = setTimeout(() => setSearchValue(""), 100) }, - [modelIdKey, setApiConfigurationField], + [modelIdKey, setApiConfigurationField, models], ) const onOpenChange = useCallback((open: boolean) => { @@ -152,6 +162,26 @@ export const ModelPicker = ({ isInitialized.current = true }, [modelIds, setApiConfigurationField, modelIdKey, selectedModelId, defaultModelId]) + // Listen for refresh result messages + useEffect(() => { + const handler = (event: MessageEvent) => { + const message = event.data as ExtensionMessage + if (message.type === "flushRouterModelsResult") { + if (message.success) { + setRefreshStatus("success") + setRefreshError(undefined) + // Reset after brief success indication + setTimeout(() => setRefreshStatus("idle"), 1500) + } else { + setRefreshStatus("error") + setRefreshError(message.error || "Refresh failed") + } + } + } + window.addEventListener("message", handler) + return () => window.removeEventListener("message", handler) + }, []) + // Cleanup timeouts on unmount to prevent test flakiness useEffect(() => { return () => { @@ -170,7 +200,26 @@ export const ModelPicker = ({ return ( <>
	[The remaining ModelPicker JSX hunk is not recoverable here: its markup was stripped during extraction (one removed line and several added lines survive only as diff markers). The last surviving added line conditionally renders the refresh error when refreshStatus === "error" and refreshError is set.]
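
Note on the getModel() changes above: the patched overrides in openrouter.ts, requesty.ts, roo.ts, and router-provider.ts all apply the same three-tier lookup — persisted resolvedModelInfo first, then the in-memory model cache, then the provider's static default. The following is a minimal, self-contained TypeScript sketch of that priority only; the ModelInfo shape and the helper names below are illustrative assumptions for the example, not the repository's actual types or API.

// Illustrative sketch of the three-tier resolution used by the patched getModel() overrides.
// The interfaces and function names here are assumptions for the example, not the real code.
interface ModelInfo {
	contextWindow: number
	maxTokens?: number
	supportsPromptCache: boolean
}

interface ResolveInput {
	modelId: string
	resolvedModelInfo?: ModelInfo // metadata persisted with the provider profile
	memoryCache: Record<string, ModelInfo> // warmed by a models fetch
	defaultInfo: ModelInfo // static fallback shipped with the provider
}

function resolveModelInfo({ modelId, resolvedModelInfo, memoryCache, defaultInfo }: ResolveInput): {
	id: string
	info: ModelInfo
	source: "persisted" | "memory-cache" | "default-fallback"
} {
	// 1) Persisted metadata wins so a restart never downgrades the model's limits.
	if (resolvedModelInfo) {
		return { id: modelId, info: resolvedModelInfo, source: "persisted" }
	}
	// 2) Otherwise use whatever the in-memory model cache currently holds.
	if (memoryCache[modelId]) {
		return { id: modelId, info: memoryCache[modelId], source: "memory-cache" }
	}
	// 3) Finally fall back to the provider's hard-coded defaults.
	return { id: modelId, info: defaultInfo, source: "default-fallback" }
}

// Example: with nothing persisted and an empty cache, the default wins.
const { source } = resolveModelInfo({
	modelId: "test/model",
	memoryCache: {},
	defaultInfo: { contextWindow: 8192, maxTokens: 1024, supportsPromptCache: false },
})
console.log("[model-cache] source:", source) // "default-fallback"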