Skip to content

Commit 54db83b

Browse files
committed
Don't immediately show an model ID error when changing API providers
1 parent 2c40224 commit 54db83b

File tree

5 files changed

+219
-159
lines changed

5 files changed

+219
-159
lines changed

webview-ui/src/components/chat/ChatView.tsx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRe
44
import { useDeepCompareEffect, useEvent, useMount } from "react-use"
55
import { Virtuoso, type VirtuosoHandle } from "react-virtuoso"
66
import styled from "styled-components"
7+
import removeMd from "remove-markdown"
8+
import { Trans } from "react-i18next"
9+
710
import {
811
ClineAsk,
912
ClineMessage,
@@ -16,25 +19,27 @@ import { findLast } from "@roo/shared/array"
1619
import { combineApiRequests } from "@roo/shared/combineApiRequests"
1720
import { combineCommandSequences } from "@roo/shared/combineCommandSequences"
1821
import { getApiMetrics } from "@roo/shared/getApiMetrics"
22+
import { AudioType } from "@roo/shared/WebviewMessage"
23+
import { getAllModes } from "@roo/shared/modes"
24+
1925
import { useExtensionState } from "@src/context/ExtensionStateContext"
2026
import { vscode } from "@src/utils/vscode"
27+
import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
28+
import { validateCommand } from "@src/utils/command-validation"
29+
import { useAppTranslation } from "@src/i18n/TranslationContext"
30+
31+
import TelemetryBanner from "../common/TelemetryBanner"
2132
import HistoryPreview from "../history/HistoryPreview"
2233
import RooHero from "../welcome/RooHero"
23-
import { normalizeApiConfiguration } from "../settings/ApiOptions"
34+
2435
import Announcement from "./Announcement"
2536
import BrowserSessionRow from "./BrowserSessionRow"
2637
import ChatRow from "./ChatRow"
2738
import ChatTextArea from "./ChatTextArea"
2839
import TaskHeader from "./TaskHeader"
2940
import AutoApproveMenu from "./AutoApproveMenu"
3041
import SystemPromptWarning from "./SystemPromptWarning"
31-
import { AudioType } from "@roo/shared/WebviewMessage"
32-
import { validateCommand } from "@src/utils/command-validation"
33-
import { getAllModes } from "@roo/shared/modes"
34-
import TelemetryBanner from "../common/TelemetryBanner"
35-
import { useAppTranslation } from "@/i18n/TranslationContext"
36-
import removeMd from "remove-markdown"
37-
import { Trans } from "react-i18next"
42+
3843
interface ChatViewProps {
3944
isHidden: boolean
4045
showAnnouncement: boolean

webview-ui/src/components/chat/TaskHeader.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import { CloudUpload, CloudDownload } from "lucide-react"
66

77
import { ClineMessage } from "@roo/shared/ExtensionMessage"
88

9-
import { getMaxTokensForModel } from "@/utils/model-utils"
10-
import { formatLargeNumber } from "@/utils/format"
11-
import { cn } from "@/lib/utils"
12-
import { Button } from "@/components/ui"
9+
import { getMaxTokensForModel } from "@src/utils/model-utils"
10+
import { formatLargeNumber } from "@src/utils/format"
11+
import { cn } from "@src/lib/utils"
12+
import { Button } from "@src/components/ui"
1313
import { useExtensionState } from "@src/context/ExtensionStateContext"
14+
import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
1415

1516
import Thumbnails from "../common/Thumbnails"
16-
import { normalizeApiConfiguration } from "../settings/ApiOptions"
1717

1818
import { TaskActions } from "./TaskActions"
1919
import { ContextWindowProgress } from "./ContextWindowProgress"

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 55 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -11,48 +11,34 @@ import { ExternalLinkIcon } from "@radix-ui/react-icons"
1111
import {
1212
ApiConfiguration,
1313
ModelInfo,
14-
anthropicDefaultModelId,
15-
anthropicModels,
1614
azureOpenAiDefaultApiVersion,
17-
bedrockDefaultModelId,
18-
bedrockModels,
19-
deepSeekDefaultModelId,
20-
deepSeekModels,
21-
geminiDefaultModelId,
22-
geminiModels,
2315
glamaDefaultModelId,
2416
glamaDefaultModelInfo,
2517
mistralDefaultModelId,
26-
mistralModels,
2718
openAiModelInfoSaneDefaults,
28-
openAiNativeDefaultModelId,
29-
openAiNativeModels,
3019
openRouterDefaultModelId,
3120
openRouterDefaultModelInfo,
32-
vertexDefaultModelId,
33-
vertexModels,
3421
unboundDefaultModelId,
3522
unboundDefaultModelInfo,
3623
requestyDefaultModelId,
3724
requestyDefaultModelInfo,
38-
xaiDefaultModelId,
39-
xaiModels,
4025
ApiProvider,
41-
vscodeLlmModels,
42-
vscodeLlmDefaultModelId,
4326
} from "@roo/shared/api"
4427
import { ExtensionMessage } from "@roo/shared/ExtensionMessage"
28+
import { AWS_REGIONS } from "@roo/shared/aws_regions"
4529

46-
import { vscode } from "@/utils/vscode"
47-
import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@/utils/validate"
30+
import { vscode } from "@src/utils/vscode"
31+
import { validateApiConfiguration, validateModelId, validateBedrockArn } from "@src/utils/validate"
32+
import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
4833
import {
4934
useOpenRouterModelProviders,
5035
OPENROUTER_DEFAULT_PROVIDER_NAME,
51-
} from "@/components/ui/hooks/useOpenRouterModelProviders"
52-
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, SelectSeparator, Button } from "@/components/ui"
53-
import { MODELS_BY_PROVIDER, PROVIDERS, VERTEX_REGIONS, REASONING_MODELS } from "./constants"
54-
import { AWS_REGIONS } from "@roo/shared/aws_regions"
36+
} from "@src/components/ui/hooks/useOpenRouterModelProviders"
37+
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue, Button } from "@src/components/ui"
38+
5539
import { VSCodeButtonLink } from "../common/VSCodeButtonLink"
40+
41+
import { MODELS_BY_PROVIDER, PROVIDERS, VERTEX_REGIONS, REASONING_MODELS } from "./constants"
5642
import { ModelInfoView } from "./ModelInfoView"
5743
import { ModelPicker } from "./ModelPicker"
5844
import { TemperatureControl } from "./TemperatureControl"
@@ -281,6 +267,7 @@ const ApiOptions = ({
281267
// Helper function to get the documentation URL and name for the currently selected provider
282268
const getSelectedProviderDocUrl = (): { url: string; name: string } | undefined => {
283269
const displayName = getProviderDisplayName(selectedProvider)
270+
284271
if (!displayName) {
285272
return undefined
286273
}
@@ -294,6 +281,49 @@ const ApiOptions = ({
294281
}
295282
}
296283

284+
const onApiProviderChange = useCallback(
285+
(value: ApiProvider) => {
286+
// It would be much easier to have a single attribute that stores
287+
// the modelId, but we have a separate attribute for each of
288+
// OpenRouter, Glama, Unbound, and Requesty.
289+
// If you switch to one of these providers and the corresponding
290+
// modelId is not set then you immediately end up in an error state.
291+
// To address that we set the modelId to the default value for th
292+
// provider if it's not already set.
293+
switch (value) {
294+
case "openrouter":
295+
if (!apiConfiguration.openRouterModelId) {
296+
setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
297+
}
298+
break
299+
case "glama":
300+
if (!apiConfiguration.glamaModelId) {
301+
setApiConfigurationField("glamaModelId", glamaDefaultModelId)
302+
}
303+
break
304+
case "unbound":
305+
if (!apiConfiguration.unboundModelId) {
306+
setApiConfigurationField("unboundModelId", unboundDefaultModelId)
307+
}
308+
break
309+
case "requesty":
310+
if (!apiConfiguration.requestyModelId) {
311+
setApiConfigurationField("requestyModelId", requestyDefaultModelId)
312+
}
313+
break
314+
}
315+
316+
setApiConfigurationField("apiProvider", value)
317+
},
318+
[
319+
setApiConfigurationField,
320+
apiConfiguration.openRouterModelId,
321+
apiConfiguration.glamaModelId,
322+
apiConfiguration.unboundModelId,
323+
apiConfiguration.requestyModelId,
324+
],
325+
)
326+
297327
return (
298328
<div className="flex flex-col gap-3">
299329
<div className="flex flex-col gap-1 relative">
@@ -312,16 +342,12 @@ const ApiOptions = ({
312342
</div>
313343
)}
314344
</div>
315-
<Select
316-
value={selectedProvider}
317-
onValueChange={(value) => setApiConfigurationField("apiProvider", value as ApiProvider)}>
345+
<Select value={selectedProvider} onValueChange={(value) => onApiProviderChange(value as ApiProvider)}>
318346
<SelectTrigger className="w-full">
319347
<SelectValue placeholder={t("settings:common.select")} />
320348
</SelectTrigger>
321349
<SelectContent>
322-
<SelectItem value="openrouter">OpenRouter</SelectItem>
323-
<SelectSeparator />
324-
{PROVIDERS.filter((p) => p.value !== "openrouter").map(({ value, label }) => (
350+
{PROVIDERS.map(({ value, label }) => (
325351
<SelectItem key={value} value={value}>
326352
{label}
327353
</SelectItem>
@@ -1738,113 +1764,4 @@ const ApiOptions = ({
17381764
)
17391765
}
17401766

1741-
export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
1742-
const provider = apiConfiguration?.apiProvider || "anthropic"
1743-
const modelId = apiConfiguration?.apiModelId
1744-
const getProviderData = (models: Record<string, ModelInfo>, defaultId: string) => {
1745-
let selectedModelId: string
1746-
let selectedModelInfo: ModelInfo
1747-
1748-
if (modelId && modelId in models) {
1749-
selectedModelId = modelId
1750-
selectedModelInfo = models[modelId]
1751-
} else {
1752-
selectedModelId = defaultId
1753-
selectedModelInfo = models[defaultId]
1754-
}
1755-
1756-
return { selectedProvider: provider, selectedModelId, selectedModelInfo }
1757-
}
1758-
1759-
switch (provider) {
1760-
case "anthropic":
1761-
return getProviderData(anthropicModels, anthropicDefaultModelId)
1762-
case "xai":
1763-
return getProviderData(xaiModels, xaiDefaultModelId)
1764-
case "bedrock":
1765-
// Special case for custom ARN
1766-
if (modelId === "custom-arn") {
1767-
return {
1768-
selectedProvider: provider,
1769-
selectedModelId: "custom-arn",
1770-
selectedModelInfo: {
1771-
maxTokens: 5000,
1772-
contextWindow: 128_000,
1773-
supportsPromptCache: false,
1774-
supportsImages: true,
1775-
},
1776-
}
1777-
}
1778-
return getProviderData(bedrockModels, bedrockDefaultModelId)
1779-
case "vertex":
1780-
return getProviderData(vertexModels, vertexDefaultModelId)
1781-
case "gemini":
1782-
return getProviderData(geminiModels, geminiDefaultModelId)
1783-
case "deepseek":
1784-
return getProviderData(deepSeekModels, deepSeekDefaultModelId)
1785-
case "openai-native":
1786-
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
1787-
case "mistral":
1788-
return getProviderData(mistralModels, mistralDefaultModelId)
1789-
case "openrouter":
1790-
return {
1791-
selectedProvider: provider,
1792-
selectedModelId: apiConfiguration?.openRouterModelId || openRouterDefaultModelId,
1793-
selectedModelInfo: apiConfiguration?.openRouterModelInfo || openRouterDefaultModelInfo,
1794-
}
1795-
case "glama":
1796-
return {
1797-
selectedProvider: provider,
1798-
selectedModelId: apiConfiguration?.glamaModelId || glamaDefaultModelId,
1799-
selectedModelInfo: apiConfiguration?.glamaModelInfo || glamaDefaultModelInfo,
1800-
}
1801-
case "unbound":
1802-
return {
1803-
selectedProvider: provider,
1804-
selectedModelId: apiConfiguration?.unboundModelId || unboundDefaultModelId,
1805-
selectedModelInfo: apiConfiguration?.unboundModelInfo || unboundDefaultModelInfo,
1806-
}
1807-
case "requesty":
1808-
return {
1809-
selectedProvider: provider,
1810-
selectedModelId: apiConfiguration?.requestyModelId || requestyDefaultModelId,
1811-
selectedModelInfo: apiConfiguration?.requestyModelInfo || requestyDefaultModelInfo,
1812-
}
1813-
case "openai":
1814-
return {
1815-
selectedProvider: provider,
1816-
selectedModelId: apiConfiguration?.openAiModelId || "",
1817-
selectedModelInfo: apiConfiguration?.openAiCustomModelInfo || openAiModelInfoSaneDefaults,
1818-
}
1819-
case "ollama":
1820-
return {
1821-
selectedProvider: provider,
1822-
selectedModelId: apiConfiguration?.ollamaModelId || "",
1823-
selectedModelInfo: openAiModelInfoSaneDefaults,
1824-
}
1825-
case "lmstudio":
1826-
return {
1827-
selectedProvider: provider,
1828-
selectedModelId: apiConfiguration?.lmStudioModelId || "",
1829-
selectedModelInfo: openAiModelInfoSaneDefaults,
1830-
}
1831-
case "vscode-lm":
1832-
const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId
1833-
const modelInfo = {
1834-
...openAiModelInfoSaneDefaults,
1835-
...vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels],
1836-
supportsImages: false, // VSCode LM API currently doesn't support images.
1837-
}
1838-
return {
1839-
selectedProvider: provider,
1840-
selectedModelId: apiConfiguration?.vsCodeLmModelSelector
1841-
? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`
1842-
: "",
1843-
selectedModelInfo: modelInfo,
1844-
}
1845-
default:
1846-
return getProviderData(anthropicModels, anthropicDefaultModelId)
1847-
}
1848-
}
1849-
18501767
export default memo(ApiOptions)

webview-ui/src/components/settings/ModelPicker.tsx

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import { ChevronsUpDown, Check, X } from "lucide-react"
55

66
import { ProviderSettings, ModelInfo } from "@roo/schemas"
77

8-
import { useAppTranslation } from "@/i18n/TranslationContext"
9-
import { cn } from "@/lib/utils"
8+
import { useAppTranslation } from "@src/i18n/TranslationContext"
9+
import { normalizeApiConfiguration } from "@src/utils/normalizeApiConfiguration"
10+
import { cn } from "@src/lib/utils"
1011
import {
1112
Command,
1213
CommandEmpty,
@@ -18,9 +19,8 @@ import {
1819
PopoverContent,
1920
PopoverTrigger,
2021
Button,
21-
} from "@/components/ui"
22+
} from "@src/components/ui"
2223

23-
import { normalizeApiConfiguration } from "./ApiOptions"
2424
import { ThinkingBudget } from "./ThinkingBudget"
2525
import { ModelInfoView } from "./ModelInfoView"
2626

@@ -205,10 +205,7 @@ export const ModelPicker = ({
205205
serviceLink: <VSCodeLink href={serviceUrl} className="text-sm" />,
206206
defaultModelLink: <VSCodeLink onClick={() => onSelect(defaultModelId)} className="text-sm" />,
207207
}}
208-
values={{
209-
serviceName,
210-
defaultModelId,
211-
}}
208+
values={{ serviceName, defaultModelId }}
212209
/>
213210
</div>
214211
</>

0 commit comments

Comments
 (0)