diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index 2a1a3f986aa5..3db2c7fb1005 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -31,3 +31,121 @@ export * from "./vercel-ai-gateway.js" export * from "./zai.js" export * from "./deepinfra.js" export * from "./minimax.js" + +import { anthropicDefaultModelId } from "./anthropic.js" +import { bedrockDefaultModelId } from "./bedrock.js" +import { cerebrasDefaultModelId } from "./cerebras.js" +import { chutesDefaultModelId } from "./chutes.js" +import { claudeCodeDefaultModelId } from "./claude-code.js" +import { deepSeekDefaultModelId } from "./deepseek.js" +import { doubaoDefaultModelId } from "./doubao.js" +import { featherlessDefaultModelId } from "./featherless.js" +import { fireworksDefaultModelId } from "./fireworks.js" +import { geminiDefaultModelId } from "./gemini.js" +import { glamaDefaultModelId } from "./glama.js" +import { groqDefaultModelId } from "./groq.js" +import { ioIntelligenceDefaultModelId } from "./io-intelligence.js" +import { litellmDefaultModelId } from "./lite-llm.js" +import { mistralDefaultModelId } from "./mistral.js" +import { moonshotDefaultModelId } from "./moonshot.js" +import { openRouterDefaultModelId } from "./openrouter.js" +import { qwenCodeDefaultModelId } from "./qwen-code.js" +import { requestyDefaultModelId } from "./requesty.js" +import { rooDefaultModelId } from "./roo.js" +import { sambaNovaDefaultModelId } from "./sambanova.js" +import { unboundDefaultModelId } from "./unbound.js" +import { vertexDefaultModelId } from "./vertex.js" +import { vscodeLlmDefaultModelId } from "./vscode-llm.js" +import { xaiDefaultModelId } from "./xai.js" +import { vercelAiGatewayDefaultModelId } from "./vercel-ai-gateway.js" +import { internationalZAiDefaultModelId, mainlandZAiDefaultModelId } from "./zai.js" +import { deepInfraDefaultModelId } from "./deepinfra.js" +import { minimaxDefaultModelId } from "./minimax.js" + +// Import the ProviderName type from provider-settings to avoid duplication +import type { ProviderName } from "../provider-settings.js" + +/** + * Get the default model ID for a given provider. + * This function returns only the provider's default model ID, without considering user configuration. + * Used as a fallback when provider models are still loading. + */ +export function getProviderDefaultModelId( + provider: ProviderName, + options: { isChina?: boolean } = { isChina: false }, +): string { + switch (provider) { + case "openrouter": + return openRouterDefaultModelId + case "requesty": + return requestyDefaultModelId + case "glama": + return glamaDefaultModelId + case "unbound": + return unboundDefaultModelId + case "litellm": + return litellmDefaultModelId + case "xai": + return xaiDefaultModelId + case "groq": + return groqDefaultModelId + case "huggingface": + return "meta-llama/Llama-3.3-70B-Instruct" + case "chutes": + return chutesDefaultModelId + case "bedrock": + return bedrockDefaultModelId + case "vertex": + return vertexDefaultModelId + case "gemini": + return geminiDefaultModelId + case "deepseek": + return deepSeekDefaultModelId + case "doubao": + return doubaoDefaultModelId + case "moonshot": + return moonshotDefaultModelId + case "minimax": + return minimaxDefaultModelId + case "zai": + return options?.isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId + case "openai-native": + return "gpt-4o" // Based on openai-native patterns + case "mistral": + return mistralDefaultModelId + case "openai": + return "" // OpenAI provider uses custom model configuration + case "ollama": + return "" // Ollama uses dynamic model selection + case "lmstudio": + return "" // LMStudio uses dynamic model selection + case "deepinfra": + return deepInfraDefaultModelId + case "vscode-lm": + return vscodeLlmDefaultModelId + case "claude-code": + return claudeCodeDefaultModelId + case "cerebras": + return cerebrasDefaultModelId + case "sambanova": + return sambaNovaDefaultModelId + case "fireworks": + return fireworksDefaultModelId + case "featherless": + return featherlessDefaultModelId + case "io-intelligence": + return ioIntelligenceDefaultModelId + case "roo": + return rooDefaultModelId + case "qwen-code": + return qwenCodeDefaultModelId + case "vercel-ai-gateway": + return vercelAiGatewayDefaultModelId + case "anthropic": + case "gemini-cli": + case "human-relay": + case "fake-ai": + default: + return anthropicDefaultModelId + } +} diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index c27d2f19f2f8..945b8aa63b21 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -93,7 +93,7 @@ describe("useSelectedModel", () => { }) }) - it("should use only specific provider info when base model info is missing", () => { + it("should fall back to default when configured model doesn't exist in available models", () => { const specificProviderInfo: ModelInfo = { maxTokens: 8192, contextWindow: 16384, @@ -106,7 +106,18 @@ describe("useSelectedModel", () => { mockUseRouterModels.mockReturnValue({ data: { - openrouter: {}, + openrouter: { + "anthropic/claude-sonnet-4.5": { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + }, + }, requesty: {}, glama: {}, unbound: {}, @@ -127,15 +138,29 @@ describe("useSelectedModel", () => { const apiConfiguration: ProviderSettings = { apiProvider: "openrouter", - openRouterModelId: "test-model", + openRouterModelId: "test-model", // This model doesn't exist in available models openRouterSpecificProvider: "test-provider", } const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) - expect(result.current.id).toBe("test-model") - expect(result.current.info).toEqual(specificProviderInfo) + // Should fall back to provider default since "test-model" doesn't exist + expect(result.current.id).toBe("anthropic/claude-sonnet-4.5") + // Should still use specific provider info for the default model if specified + expect(result.current.info).toEqual({ + ...{ + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + }, + ...specificProviderInfo, + }) }) it("should demonstrate the merging behavior validates the comment about missing fields", () => { @@ -244,12 +269,12 @@ describe("useSelectedModel", () => { expect(result.current.info).toEqual(baseModelInfo) }) - it("should fall back to default when both base and specific provider info are missing", () => { + it("should fall back to default when configured model and provider don't exist", () => { mockUseRouterModels.mockReturnValue({ data: { openrouter: { - "anthropic/claude-sonnet-4": { - // Default model + "anthropic/claude-sonnet-4.5": { + // Default model - using correct default model name maxTokens: 8192, contextWindow: 200_000, supportsImages: true, @@ -285,8 +310,19 @@ describe("useSelectedModel", () => { const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) - expect(result.current.id).toBe("non-existent-model") - expect(result.current.info).toBeUndefined() + // Should fall back to provider default since "non-existent-model" doesn't exist + expect(result.current.id).toBe("anthropic/claude-sonnet-4.5") + // Should use base model info since provider doesn't exist + expect(result.current.info).toEqual({ + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 3.0, + outputPrice: 15.0, + cacheWritesPrice: 3.75, + cacheReadsPrice: 0.3, + }) }) }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 296b262c3731..c2a57942d24e 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -2,62 +2,33 @@ import { type ProviderName, type ProviderSettings, type ModelInfo, - anthropicDefaultModelId, anthropicModels, - bedrockDefaultModelId, bedrockModels, - cerebrasDefaultModelId, cerebrasModels, - deepSeekDefaultModelId, deepSeekModels, - moonshotDefaultModelId, moonshotModels, - minimaxDefaultModelId, minimaxModels, - geminiDefaultModelId, geminiModels, - mistralDefaultModelId, mistralModels, openAiModelInfoSaneDefaults, - openAiNativeDefaultModelId, openAiNativeModels, - vertexDefaultModelId, vertexModels, - xaiDefaultModelId, xaiModels, groqModels, - groqDefaultModelId, - chutesDefaultModelId, vscodeLlmModels, vscodeLlmDefaultModelId, - openRouterDefaultModelId, - requestyDefaultModelId, - glamaDefaultModelId, - unboundDefaultModelId, - litellmDefaultModelId, - claudeCodeDefaultModelId, claudeCodeModels, sambaNovaModels, - sambaNovaDefaultModelId, doubaoModels, - doubaoDefaultModelId, - internationalZAiDefaultModelId, - mainlandZAiDefaultModelId, internationalZAiModels, mainlandZAiModels, fireworksModels, - fireworksDefaultModelId, featherlessModels, - featherlessDefaultModelId, - ioIntelligenceDefaultModelId, ioIntelligenceModels, - rooDefaultModelId, - qwenCodeDefaultModelId, qwenCodeModels, - vercelAiGatewayDefaultModelId, BEDROCK_1M_CONTEXT_MODEL_IDS, - deepInfraDefaultModelId, isDynamicProvider, + getProviderDefaultModelId, } from "@roo-code/types" import type { ModelRecord, RouterModels } from "@roo/api" @@ -67,6 +38,18 @@ import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders" import { useLmStudioModels } from "./useLmStudioModels" import { useOllamaModels } from "./useOllamaModels" +/** + * Helper to get a validated model ID for dynamic providers. + * Returns the configured model ID if it exists in the available models, otherwise returns the default. + */ +function getValidatedModelId( + configuredId: string | undefined, + availableModels: ModelRecord | undefined, + defaultModelId: string, +): string { + return configuredId && availableModels?.[configuredId] ? configuredId : defaultModelId +} + export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { const provider = apiConfiguration?.apiProvider || "anthropic" const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined @@ -90,10 +73,17 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { const needLmStudio = typeof lmStudioModelId !== "undefined" const needOllama = typeof ollamaModelId !== "undefined" + const hasValidRouterData = needRouterModels + ? routerModels.data && + routerModels.data[provider] !== undefined && + typeof routerModels.data[provider] === "object" && + !routerModels.isLoading + : true + const isReady = (!needLmStudio || typeof lmStudioModels.data !== "undefined") && (!needOllama || typeof ollamaModels.data !== "undefined") && - (!needRouterModels || typeof routerModels.data !== "undefined") && + hasValidRouterData && (!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined") const { id, info } = @@ -106,7 +96,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined, ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined, }) - : { id: anthropicDefaultModelId, info: undefined } + : { id: getProviderDefaultModelId(provider), info: undefined } return { provider, @@ -143,10 +133,11 @@ function getSelectedModel({ // the `undefined` case are used to show the invalid selection to prevent // users from seeing the default model if their selection is invalid // this gives a better UX than showing the default model + const defaultModelId = getProviderDefaultModelId(provider) switch (provider) { case "openrouter": { - const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId - let info = routerModels.openrouter[id] + const id = getValidatedModelId(apiConfiguration.openRouterModelId, routerModels.openrouter, defaultModelId) + let info = routerModels.openrouter?.[id] const specificProvider = apiConfiguration.openRouterSpecificProvider if (specificProvider && openRouterModelProviders[specificProvider]) { @@ -161,32 +152,32 @@ function getSelectedModel({ return { id, info } } case "requesty": { - const id = apiConfiguration.requestyModelId ?? requestyDefaultModelId - const info = routerModels.requesty[id] + const id = getValidatedModelId(apiConfiguration.requestyModelId, routerModels.requesty, defaultModelId) + const info = routerModels.requesty?.[id] return { id, info } } case "glama": { - const id = apiConfiguration.glamaModelId ?? glamaDefaultModelId - const info = routerModels.glama[id] + const id = getValidatedModelId(apiConfiguration.glamaModelId, routerModels.glama, defaultModelId) + const info = routerModels.glama?.[id] return { id, info } } case "unbound": { - const id = apiConfiguration.unboundModelId ?? unboundDefaultModelId - const info = routerModels.unbound[id] + const id = getValidatedModelId(apiConfiguration.unboundModelId, routerModels.unbound, defaultModelId) + const info = routerModels.unbound?.[id] return { id, info } } case "litellm": { - const id = apiConfiguration.litellmModelId ?? litellmDefaultModelId - const info = routerModels.litellm[id] + const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId) + const info = routerModels.litellm?.[id] return { id, info } } case "xai": { - const id = apiConfiguration.apiModelId ?? xaiDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = xaiModels[id as keyof typeof xaiModels] return info ? { id, info } : { id, info: undefined } } case "groq": { - const id = apiConfiguration.apiModelId ?? groqDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = groqModels[id as keyof typeof groqModels] return { id, info } } @@ -201,12 +192,12 @@ function getSelectedModel({ return { id, info } } case "chutes": { - const id = apiConfiguration.apiModelId ?? chutesDefaultModelId - const info = routerModels.chutes[id] + const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.chutes, defaultModelId) + const info = routerModels.chutes?.[id] return { id, info } } case "bedrock": { - const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const baseInfo = bedrockModels[id as keyof typeof bedrockModels] // Special case for custom ARN. @@ -230,50 +221,50 @@ function getSelectedModel({ return { id, info: baseInfo } } case "vertex": { - const id = apiConfiguration.apiModelId ?? vertexDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = vertexModels[id as keyof typeof vertexModels] return { id, info } } case "gemini": { - const id = apiConfiguration.apiModelId ?? geminiDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = geminiModels[id as keyof typeof geminiModels] return { id, info } } case "deepseek": { - const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = deepSeekModels[id as keyof typeof deepSeekModels] return { id, info } } case "doubao": { - const id = apiConfiguration.apiModelId ?? doubaoDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = doubaoModels[id as keyof typeof doubaoModels] return { id, info } } case "moonshot": { - const id = apiConfiguration.apiModelId ?? moonshotDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = moonshotModels[id as keyof typeof moonshotModels] return { id, info } } case "minimax": { - const id = apiConfiguration.apiModelId ?? minimaxDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = minimaxModels[id as keyof typeof minimaxModels] return { id, info } } case "zai": { const isChina = apiConfiguration.zaiApiLine === "china_coding" const models = isChina ? mainlandZAiModels : internationalZAiModels - const defaultModelId = isChina ? mainlandZAiDefaultModelId : internationalZAiDefaultModelId + const defaultModelId = getProviderDefaultModelId(provider, { isChina }) const id = apiConfiguration.apiModelId ?? defaultModelId const info = models[id as keyof typeof models] return { id, info } } case "openai-native": { - const id = apiConfiguration.apiModelId ?? openAiNativeDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = openAiNativeModels[id as keyof typeof openAiNativeModels] return { id, info } } case "mistral": { - const id = apiConfiguration.apiModelId ?? mistralDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = mistralModels[id as keyof typeof mistralModels] return { id, info } } @@ -307,7 +298,7 @@ function getSelectedModel({ } } case "deepinfra": { - const id = apiConfiguration.deepInfraModelId ?? deepInfraDefaultModelId + const id = getValidatedModelId(apiConfiguration.deepInfraModelId, routerModels.deepinfra, defaultModelId) const info = routerModels.deepinfra?.[id] return { id, info } } @@ -321,49 +312,56 @@ function getSelectedModel({ } case "claude-code": { // Claude Code models extend anthropic models but with images and prompt caching disabled - const id = apiConfiguration.apiModelId ?? claudeCodeDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = claudeCodeModels[id as keyof typeof claudeCodeModels] return { id, info: { ...openAiModelInfoSaneDefaults, ...info } } } case "cerebras": { - const id = apiConfiguration.apiModelId ?? cerebrasDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = cerebrasModels[id as keyof typeof cerebrasModels] return { id, info } } case "sambanova": { - const id = apiConfiguration.apiModelId ?? sambaNovaDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = sambaNovaModels[id as keyof typeof sambaNovaModels] return { id, info } } case "fireworks": { - const id = apiConfiguration.apiModelId ?? fireworksDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = fireworksModels[id as keyof typeof fireworksModels] return { id, info } } case "featherless": { - const id = apiConfiguration.apiModelId ?? featherlessDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = featherlessModels[id as keyof typeof featherlessModels] return { id, info } } case "io-intelligence": { - const id = apiConfiguration.ioIntelligenceModelId ?? ioIntelligenceDefaultModelId + const id = getValidatedModelId( + apiConfiguration.ioIntelligenceModelId, + routerModels["io-intelligence"], + defaultModelId, + ) const info = routerModels["io-intelligence"]?.[id] ?? ioIntelligenceModels[id as keyof typeof ioIntelligenceModels] return { id, info } } case "roo": { - // Roo is a dynamic provider - models are loaded from API - const id = apiConfiguration.apiModelId ?? rooDefaultModelId - const info = routerModels.roo[id] + const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.roo, defaultModelId) + const info = routerModels.roo?.[id] return { id, info } } case "qwen-code": { - const id = apiConfiguration.apiModelId ?? qwenCodeDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const info = qwenCodeModels[id as keyof typeof qwenCodeModels] return { id, info } } case "vercel-ai-gateway": { - const id = apiConfiguration.vercelAiGatewayModelId ?? vercelAiGatewayDefaultModelId + const id = getValidatedModelId( + apiConfiguration.vercelAiGatewayModelId, + routerModels["vercel-ai-gateway"], + defaultModelId, + ) const info = routerModels["vercel-ai-gateway"]?.[id] return { id, info } } @@ -372,7 +370,7 @@ function getSelectedModel({ // case "fake-ai": default: { provider satisfies "anthropic" | "gemini-cli" | "qwen-code" | "human-relay" | "fake-ai" - const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId + const id = apiConfiguration.apiModelId ?? defaultModelId const baseInfo = anthropicModels[id as keyof typeof anthropicModels] // Apply 1M context beta tier pricing for Claude Sonnet 4