Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion webview-ui/src/components/settings/ApiErrorMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ interface ApiErrorMessageProps {
}

export const ApiErrorMessage = ({ errorMessage, children }: ApiErrorMessageProps) => (
<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm">
<div className="flex flex-col gap-2 text-vscode-errorForeground text-sm" data-testid="api-error-message">
<div className="flex flex-row items-center gap-1">
<div className="codicon codicon-close" />
<div>{errorMessage}</div>
Expand Down
122 changes: 84 additions & 38 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,20 @@ import {
glamaDefaultModelId,
unboundDefaultModelId,
litellmDefaultModelId,
openAiNativeDefaultModelId,
anthropicDefaultModelId,
geminiDefaultModelId,
deepSeekDefaultModelId,
mistralDefaultModelId,
xaiDefaultModelId,
groqDefaultModelId,
chutesDefaultModelId,
bedrockDefaultModelId,
vertexDefaultModelId,
} from "@roo-code/types"

import { vscode } from "@src/utils/vscode"
import { validateApiConfiguration } from "@src/utils/validate"
import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate"
import { useAppTranslation } from "@src/i18n/TranslationContext"
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel"
Expand Down Expand Up @@ -176,8 +186,11 @@ const ApiOptions = ({
)

useEffect(() => {
const apiValidationResult = validateApiConfiguration(apiConfiguration, routerModels, organizationAllowList)

const apiValidationResult = validateApiConfigurationExcludingModelErrors(
apiConfiguration,
routerModels,
organizationAllowList,
)
setErrorMessage(apiValidationResult)
}, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage])

Expand All @@ -187,63 +200,90 @@ const ApiOptions = ({

const filteredModels = filterModels(models, selectedProvider, organizationAllowList)

return filteredModels
const modelOptions = filteredModels
? Object.keys(filteredModels).map((modelId) => ({
value: modelId,
label: modelId,
}))
: []

return modelOptions
}, [selectedProvider, organizationAllowList])

const onProviderChange = useCallback(
(value: ProviderName) => {
setApiConfigurationField("apiProvider", value)

// It would be much easier to have a single attribute that stores
// the modelId, but we have a separate attribute for each of
// OpenRouter, Glama, Unbound, and Requesty.
// If you switch to one of these providers and the corresponding
// modelId is not set then you immediately end up in an error state.
// To address that we set the modelId to the default value for th
// provider if it's not already set.
switch (value) {
case "openrouter":
if (!apiConfiguration.openRouterModelId) {
setApiConfigurationField("openRouterModelId", openRouterDefaultModelId)
}
break
case "glama":
if (!apiConfiguration.glamaModelId) {
setApiConfigurationField("glamaModelId", glamaDefaultModelId)
}
break
case "unbound":
if (!apiConfiguration.unboundModelId) {
setApiConfigurationField("unboundModelId", unboundDefaultModelId)
}
break
case "requesty":
if (!apiConfiguration.requestyModelId) {
setApiConfigurationField("requestyModelId", requestyDefaultModelId)
}
break
case "litellm":
if (!apiConfiguration.litellmModelId) {
setApiConfigurationField("litellmModelId", litellmDefaultModelId)
const validateAndResetModel = (
modelId: string | undefined,
field: keyof ProviderSettings,
defaultValue?: string,
) => {
// in case we haven't set a default value for a provider
if (!defaultValue) return

// only set default if no model is set, but don't reset invalid models
// let users see and decide what to do with invalid model selections
const shouldSetDefault = !modelId

if (shouldSetDefault) {
setApiConfigurationField(field, defaultValue)
}
}

// Define a mapping object that associates each provider with its model configuration
const PROVIDER_MODEL_CONFIG: Partial<
Record<
ProviderName,
{
field: keyof ProviderSettings
default?: string
}
break
>
> = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably pull this out into a constant and strengthen the typing using this: https://github.com/RooCodeInc/Roo-Code/blob/main/packages/types/src/provider-settings.ts#L263

openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId },
glama: { field: "glamaModelId", default: glamaDefaultModelId },
unbound: { field: "unboundModelId", default: unboundDefaultModelId },
requesty: { field: "requestyModelId", default: requestyDefaultModelId },
litellm: { field: "litellmModelId", default: litellmDefaultModelId },
anthropic: { field: "apiModelId", default: anthropicDefaultModelId },
"openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId },
gemini: { field: "apiModelId", default: geminiDefaultModelId },
deepseek: { field: "apiModelId", default: deepSeekDefaultModelId },
mistral: { field: "apiModelId", default: mistralDefaultModelId },
xai: { field: "apiModelId", default: xaiDefaultModelId },
groq: { field: "apiModelId", default: groqDefaultModelId },
chutes: { field: "apiModelId", default: chutesDefaultModelId },
bedrock: { field: "apiModelId", default: bedrockDefaultModelId },
vertex: { field: "apiModelId", default: vertexDefaultModelId },
openai: { field: "openAiModelId" },
ollama: { field: "ollamaModelId" },
lmstudio: { field: "lmStudioModelId" },
}

setApiConfigurationField("apiProvider", value)
const config = PROVIDER_MODEL_CONFIG[value]
if (config) {
validateAndResetModel(
apiConfiguration[config.field] as string | undefined,
config.field,
config.default,
)
}
},
[
setApiConfigurationField,
apiConfiguration.openRouterModelId,
apiConfiguration.glamaModelId,
apiConfiguration.unboundModelId,
apiConfiguration.requestyModelId,
apiConfiguration.litellmModelId,
],
[setApiConfigurationField, apiConfiguration],
)

const modelValidationError = useMemo(() => {
return getModelValidationError(apiConfiguration, routerModels, organizationAllowList)
}, [apiConfiguration, routerModels, organizationAllowList])

const docs = useMemo(() => {
const provider = PROVIDERS.find(({ value }) => value === selectedProvider)
const name = provider?.label
Expand Down Expand Up @@ -303,6 +343,7 @@ const ApiOptions = ({
uriScheme={uriScheme}
fromWelcomeView={fromWelcomeView}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand All @@ -313,6 +354,7 @@ const ApiOptions = ({
routerModels={routerModels}
refetchRouterModels={refetchRouterModels}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand All @@ -323,6 +365,7 @@ const ApiOptions = ({
routerModels={routerModels}
uriScheme={uriScheme}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand All @@ -332,6 +375,7 @@ const ApiOptions = ({
setApiConfigurationField={setApiConfigurationField}
routerModels={routerModels}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand Down Expand Up @@ -368,6 +412,7 @@ const ApiOptions = ({
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand Down Expand Up @@ -404,6 +449,7 @@ const ApiOptions = ({
apiConfiguration={apiConfiguration}
setApiConfigurationField={setApiConfigurationField}
organizationAllowList={organizationAllowList}
modelValidationError={modelValidationError}
/>
)}

Expand Down
13 changes: 11 additions & 2 deletions webview-ui/src/components/settings/ModelPicker.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
} from "@src/components/ui"

import { ModelInfoView } from "./ModelInfoView"
import { ApiErrorMessage } from "./ApiErrorMessage"

type ModelIdKey = keyof Pick<
ProviderSettings,
Expand All @@ -38,6 +39,7 @@ interface ModelPickerProps {
apiConfiguration: ProviderSettings
setApiConfigurationField: <K extends keyof ProviderSettings>(field: K, value: ProviderSettings[K]) => void
organizationAllowList: OrganizationAllowList
errorMessage?: string
}

export const ModelPicker = ({
Expand All @@ -49,6 +51,7 @@ export const ModelPicker = ({
apiConfiguration,
setApiConfigurationField,
organizationAllowList,
errorMessage,
}: ModelPickerProps) => {
const { t } = useAppTranslation()

Expand Down Expand Up @@ -119,7 +122,8 @@ export const ModelPicker = ({
variant="combobox"
role="combobox"
aria-expanded={open}
className="w-full justify-between">
className="w-full justify-between"
data-testid="model-picker-button">
<div>{selectedModelId ?? t("settings:common.select")}</div>
<ChevronsUpDown className="opacity-50" />
</Button>
Expand Down Expand Up @@ -154,7 +158,11 @@ export const ModelPicker = ({
</CommandEmpty>
<CommandGroup>
{modelIds.map((model) => (
<CommandItem key={model} value={model} onSelect={onSelect}>
<CommandItem
key={model}
value={model}
onSelect={onSelect}
data-testid={`model-option-${model}`}>
{model}
<Check
className={cn(
Expand All @@ -177,6 +185,7 @@ export const ModelPicker = ({
</PopoverContent>
</Popover>
</div>
{errorMessage && <ApiErrorMessage errorMessage={errorMessage} />}
{selectedModelId && selectedModelInfo && (
<ModelInfoView
apiProvider={apiConfiguration.apiProvider}
Expand Down
Loading
Loading