Skip to content

Commit b8505fe

Browse files
authored
fix: ambiguous model id error (#4306)
1 parent fca4bea commit b8505fe

File tree

14 files changed

+509
-92
lines changed

14 files changed

+509
-92
lines changed

webview-ui/src/components/settings/ApiErrorMessage.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ interface ApiErrorMessageProps {
66
}
77

88
export const ApiErrorMessage = ({ errorMessage, children }: ApiErrorMessageProps) => (
9-
<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm">
9+
<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm" data-testid="api-error-message">
1010
<div className="flex flex-row items-center gap-1">
1111
<div className="codicon codicon-close" />
1212
<div>{errorMessage}</div>

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 84 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,20 @@ import {
1111
glamaDefaultModelId,
1212
unboundDefaultModelId,
1313
litellmDefaultModelId,
14+
openAiNativeDefaultModelId,
15+
anthropicDefaultModelId,
16+
geminiDefaultModelId,
17+
deepSeekDefaultModelId,
18+
mistralDefaultModelId,
19+
xaiDefaultModelId,
20+
groqDefaultModelId,
21+
chutesDefaultModelId,
22+
bedrockDefaultModelId,
23+
vertexDefaultModelId,
1424
} from "@roo-code/types"
1525

1626
import { vscode } from "@src/utils/vscode"
17-
import { validateApiConfiguration } from "@src/utils/validate"
27+
import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate"
1828
import { useAppTranslation } from "@src/i18n/TranslationContext"
1929
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
2030
import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel"
@@ -176,8 +186,11 @@ const ApiOptions = ({
176186
)
177187

178188
useEffect(() => {
179-
const apiValidationResult = validateApiConfiguration(apiConfiguration, routerModels, organizationAllowList)
180-
189+
const apiValidationResult = validateApiConfigurationExcludingModelErrors(
190+
apiConfiguration,
191+
routerModels,
192+
organizationAllowList,
193+
)
181194
setErrorMessage(apiValidationResult)
182195
}, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage])
183196

@@ -187,63 +200,90 @@ const ApiOptions = ({
187200

188201
const filteredModels = filterModels(models, selectedProvider, organizationAllowList)
189202

190-
return filteredModels
203+
const modelOptions = filteredModels
191204
? Object.keys(filteredModels).map((modelId) => ({
192205
value: modelId,
193206
label: modelId,
194207
}))
195208
: []
209+
210+
return modelOptions
196211
}, [selectedProvider, organizationAllowList])
197212

198213
const onProviderChange = useCallback(
199214
(value: ProviderName) => {
215+
setApiConfigurationField("apiProvider", value)
216+
200217
// It would be much easier to have a single attribute that stores
201218
// the modelId, but we have a separate attribute for each of
202219
// OpenRouter, Glama, Unbound, and Requesty.
203220
// If you switch to one of these providers and the corresponding
204221
// modelId is not set then you immediately end up in an error state.
205222
// To address that we set the modelId to the default value for th
206223
// provider if it's not already set.
207-
switch (value) {
208-
case "openrouter":
209-
if (!apiConfiguration.openRouterModelId) {
210-
setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
211-
}
212-
break
213-
case "glama":
214-
if (!apiConfiguration.glamaModelId) {
215-
setApiConfigurationField("glamaModelId", glamaDefaultModelId)
216-
}
217-
break
218-
case "unbound":
219-
if (!apiConfiguration.unboundModelId) {
220-
setApiConfigurationField("unboundModelId", unboundDefaultModelId)
221-
}
222-
break
223-
case "requesty":
224-
if (!apiConfiguration.requestyModelId) {
225-
setApiConfigurationField("requestyModelId", requestyDefaultModelId)
226-
}
227-
break
228-
case "litellm":
229-
if (!apiConfiguration.litellmModelId) {
230-
setApiConfigurationField("litellmModelId", litellmDefaultModelId)
224+
const validateAndResetModel = (
225+
modelId: string | undefined,
226+
field: keyof ProviderSettings,
227+
defaultValue?: string,
228+
) => {
229+
// in case we haven't set a default value for a provider
230+
if (!defaultValue) return
231+
232+
// only set default if no model is set, but don't reset invalid models
233+
// let users see and decide what to do with invalid model selections
234+
const shouldSetDefault = !modelId
235+
236+
if (shouldSetDefault) {
237+
setApiConfigurationField(field, defaultValue)
238+
}
239+
}
240+
241+
// Define a mapping object that associates each provider with its model configuration
242+
const PROVIDER_MODEL_CONFIG: Partial<
243+
Record<
244+
ProviderName,
245+
{
246+
field: keyof ProviderSettings
247+
default?: string
231248
}
232-
break
249+
>
250+
> = {
251+
openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
252+
glama: { field: "glamaModelId", default: glamaDefaultModelId },
253+
unbound: { field: "unboundModelId", default: unboundDefaultModelId },
254+
requesty: { field: "requestyModelId", default: requestyDefaultModelId },
255+
litellm: { field: "litellmModelId", default: litellmDefaultModelId },
256+
anthropic: { field: "apiModelId", default: anthropicDefaultModelId },
257+
"openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId },
258+
gemini: { field: "apiModelId", default: geminiDefaultModelId },
259+
deepseek: { field: "apiModelId", default: deepSeekDefaultModelId },
260+
mistral: { field: "apiModelId", default: mistralDefaultModelId },
261+
xai: { field: "apiModelId", default: xaiDefaultModelId },
262+
groq: { field: "apiModelId", default: groqDefaultModelId },
263+
chutes: { field: "apiModelId", default: chutesDefaultModelId },
264+
bedrock: { field: "apiModelId", default: bedrockDefaultModelId },
265+
vertex: { field: "apiModelId", default: vertexDefaultModelId },
266+
openai: { field: "openAiModelId" },
267+
ollama: { field: "ollamaModelId" },
268+
lmstudio: { field: "lmStudioModelId" },
233269
}
234270

235-
setApiConfigurationField("apiProvider", value)
271+
const config = PROVIDER_MODEL_CONFIG[value]
272+
if (config) {
273+
validateAndResetModel(
274+
apiConfiguration[config.field] as string | undefined,
275+
config.field,
276+
config.default,
277+
)
278+
}
236279
},
237-
[
238-
setApiConfigurationField,
239-
apiConfiguration.openRouterModelId,
240-
apiConfiguration.glamaModelId,
241-
apiConfiguration.unboundModelId,
242-
apiConfiguration.requestyModelId,
243-
apiConfiguration.litellmModelId,
244-
],
280+
[setApiConfigurationField, apiConfiguration],
245281
)
246282

283+
const modelValidationError = useMemo(() => {
284+
return getModelValidationError(apiConfiguration, routerModels, organizationAllowList)
285+
}, [apiConfiguration, routerModels, organizationAllowList])
286+
247287
const docs = useMemo(() => {
248288
const provider = PROVIDERS.find(({ value }) => value === selectedProvider)
249289
const name = provider?.label
@@ -303,6 +343,7 @@ const ApiOptions = ({
303343
uriScheme={uriScheme}
304344
fromWelcomeView={fromWelcomeView}
305345
organizationAllowList={organizationAllowList}
346+
modelValidationError={modelValidationError}
306347
/>
307348
)}
308349

@@ -313,6 +354,7 @@ const ApiOptions = ({
313354
routerModels={routerModels}
314355
refetchRouterModels={refetchRouterModels}
315356
organizationAllowList={organizationAllowList}
357+
modelValidationError={modelValidationError}
316358
/>
317359
)}
318360

@@ -323,6 +365,7 @@ const ApiOptions = ({
323365
routerModels={routerModels}
324366
uriScheme={uriScheme}
325367
organizationAllowList={organizationAllowList}
368+
modelValidationError={modelValidationError}
326369
/>
327370
)}
328371

@@ -332,6 +375,7 @@ const ApiOptions = ({
332375
setApiConfigurationField={setApiConfigurationField}
333376
routerModels={routerModels}
334377
organizationAllowList={organizationAllowList}
378+
modelValidationError={modelValidationError}
335379
/>
336380
)}
337381

@@ -368,6 +412,7 @@ const ApiOptions = ({
368412
apiConfiguration={apiConfiguration}
369413
setApiConfigurationField={setApiConfigurationField}
370414
organizationAllowList={organizationAllowList}
415+
modelValidationError={modelValidationError}
371416
/>
372417
)}
373418

@@ -404,6 +449,7 @@ const ApiOptions = ({
404449
apiConfiguration={apiConfiguration}
405450
setApiConfigurationField={setApiConfigurationField}
406451
organizationAllowList={organizationAllowList}
452+
modelValidationError={modelValidationError}
407453
/>
408454
)}
409455

webview-ui/src/components/settings/ModelPicker.tsx

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {
2323
} from "@src/components/ui"
2424

2525
import { ModelInfoView } from "./ModelInfoView"
26+
import { ApiErrorMessage } from "./ApiErrorMessage"
2627

2728
type ModelIdKey = keyof Pick<
2829
ProviderSettings,
@@ -38,6 +39,7 @@ interface ModelPickerProps {
3839
apiConfiguration: ProviderSettings
3940
setApiConfigurationField: <K extends keyof ProviderSettings>(field: K, value: ProviderSettings[K]) => void
4041
organizationAllowList: OrganizationAllowList
42+
errorMessage?: string
4143
}
4244

4345
export const ModelPicker = ({
@@ -49,6 +51,7 @@ export const ModelPicker = ({
4951
apiConfiguration,
5052
setApiConfigurationField,
5153
organizationAllowList,
54+
errorMessage,
5255
}: ModelPickerProps) => {
5356
const { t } = useAppTranslation()
5457

@@ -119,7 +122,8 @@ export const ModelPicker = ({
119122
variant="combobox"
120123
role="combobox"
121124
aria-expanded={open}
122-
className="w-full justify-between">
125+
className="w-full justify-between"
126+
data-testid="model-picker-button">
123127
<div>{selectedModelId ?? t("settings:common.select")}</div>
124128
<ChevronsUpDown className="opacity-50" />
125129
</Button>
@@ -154,7 +158,11 @@ export const ModelPicker = ({
154158
</CommandEmpty>
155159
<CommandGroup>
156160
{modelIds.map((model) => (
157-
<CommandItem key={model} value={model} onSelect={onSelect}>
161+
<CommandItem
162+
key={model}
163+
value={model}
164+
onSelect={onSelect}
165+
data-testid={`model-option-${model}`}>
158166
{model}
159167
<Check
160168
className={cn(
@@ -177,6 +185,7 @@ export const ModelPicker = ({
177185
</PopoverContent>
178186
</Popover>
179187
</div>
188+
{errorMessage && <ApiErrorMessage errorMessage={errorMessage} />}
180189
{selectedModelId && selectedModelInfo && (
181190
<ModelInfoView
182191
apiProvider={apiConfiguration.apiProvider}

0 commit comments

Comments
 (0)