diff --git a/webview-ui/src/components/settings/ApiErrorMessage.tsx b/webview-ui/src/components/settings/ApiErrorMessage.tsx index 06764a1bfa..5e14edcfff 100644 --- a/webview-ui/src/components/settings/ApiErrorMessage.tsx +++ b/webview-ui/src/components/settings/ApiErrorMessage.tsx @@ -6,7 +6,7 @@ interface ApiErrorMessageProps { } export const ApiErrorMessage = ({ errorMessage, children }: ApiErrorMessageProps) => ( -
+
{errorMessage}
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 905f34a860..c55999efbd 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -11,10 +11,20 @@ import { glamaDefaultModelId, unboundDefaultModelId, litellmDefaultModelId, + openAiNativeDefaultModelId, + anthropicDefaultModelId, + geminiDefaultModelId, + deepSeekDefaultModelId, + mistralDefaultModelId, + xaiDefaultModelId, + groqDefaultModelId, + chutesDefaultModelId, + bedrockDefaultModelId, + vertexDefaultModelId, } from "@roo-code/types" import { vscode } from "@src/utils/vscode" -import { validateApiConfiguration } from "@src/utils/validate" +import { validateApiConfigurationExcludingModelErrors, getModelValidationError } from "@src/utils/validate" import { useAppTranslation } from "@src/i18n/TranslationContext" import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel" @@ -176,8 +186,11 @@ const ApiOptions = ({ ) useEffect(() => { - const apiValidationResult = validateApiConfiguration(apiConfiguration, routerModels, organizationAllowList) - + const apiValidationResult = validateApiConfigurationExcludingModelErrors( + apiConfiguration, + routerModels, + organizationAllowList, + ) setErrorMessage(apiValidationResult) }, [apiConfiguration, routerModels, organizationAllowList, setErrorMessage]) @@ -187,16 +200,20 @@ const ApiOptions = ({ const filteredModels = filterModels(models, selectedProvider, organizationAllowList) - return filteredModels + const modelOptions = filteredModels ? Object.keys(filteredModels).map((modelId) => ({ value: modelId, label: modelId, })) : [] + + return modelOptions }, [selectedProvider, organizationAllowList]) const onProviderChange = useCallback( (value: ProviderName) => { + setApiConfigurationField("apiProvider", value) + // It would be much easier to have a single attribute that stores // the modelId, but we have a separate attribute for each of // OpenRouter, Glama, Unbound, and Requesty. @@ -204,46 +221,69 @@ const ApiOptions = ({ // modelId is not set then you immediately end up in an error state. // To address that we set the modelId to the default value for th // provider if it's not already set. - switch (value) { - case "openrouter": - if (!apiConfiguration.openRouterModelId) { - setApiConfigurationField("openRouterModelId", openRouterDefaultModelId) - } - break - case "glama": - if (!apiConfiguration.glamaModelId) { - setApiConfigurationField("glamaModelId", glamaDefaultModelId) - } - break - case "unbound": - if (!apiConfiguration.unboundModelId) { - setApiConfigurationField("unboundModelId", unboundDefaultModelId) - } - break - case "requesty": - if (!apiConfiguration.requestyModelId) { - setApiConfigurationField("requestyModelId", requestyDefaultModelId) - } - break - case "litellm": - if (!apiConfiguration.litellmModelId) { - setApiConfigurationField("litellmModelId", litellmDefaultModelId) + const validateAndResetModel = ( + modelId: string | undefined, + field: keyof ProviderSettings, + defaultValue?: string, + ) => { + // in case we haven't set a default value for a provider + if (!defaultValue) return + + // only set default if no model is set, but don't reset invalid models + // let users see and decide what to do with invalid model selections + const shouldSetDefault = !modelId + + if (shouldSetDefault) { + setApiConfigurationField(field, defaultValue) + } + } + + // Define a mapping object that associates each provider with its model configuration + const PROVIDER_MODEL_CONFIG: Partial< + Record< + ProviderName, + { + field: keyof ProviderSettings + default?: string } - break + > + > = { + openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId }, + glama: { field: "glamaModelId", default: glamaDefaultModelId }, + unbound: { field: "unboundModelId", default: unboundDefaultModelId }, + requesty: { field: "requestyModelId", default: requestyDefaultModelId }, + litellm: { field: "litellmModelId", default: litellmDefaultModelId }, + anthropic: { field: "apiModelId", default: anthropicDefaultModelId }, + "openai-native": { field: "apiModelId", default: openAiNativeDefaultModelId }, + gemini: { field: "apiModelId", default: geminiDefaultModelId }, + deepseek: { field: "apiModelId", default: deepSeekDefaultModelId }, + mistral: { field: "apiModelId", default: mistralDefaultModelId }, + xai: { field: "apiModelId", default: xaiDefaultModelId }, + groq: { field: "apiModelId", default: groqDefaultModelId }, + chutes: { field: "apiModelId", default: chutesDefaultModelId }, + bedrock: { field: "apiModelId", default: bedrockDefaultModelId }, + vertex: { field: "apiModelId", default: vertexDefaultModelId }, + openai: { field: "openAiModelId" }, + ollama: { field: "ollamaModelId" }, + lmstudio: { field: "lmStudioModelId" }, } - setApiConfigurationField("apiProvider", value) + const config = PROVIDER_MODEL_CONFIG[value] + if (config) { + validateAndResetModel( + apiConfiguration[config.field] as string | undefined, + config.field, + config.default, + ) + } }, - [ - setApiConfigurationField, - apiConfiguration.openRouterModelId, - apiConfiguration.glamaModelId, - apiConfiguration.unboundModelId, - apiConfiguration.requestyModelId, - apiConfiguration.litellmModelId, - ], + [setApiConfigurationField, apiConfiguration], ) + const modelValidationError = useMemo(() => { + return getModelValidationError(apiConfiguration, routerModels, organizationAllowList) + }, [apiConfiguration, routerModels, organizationAllowList]) + const docs = useMemo(() => { const provider = PROVIDERS.find(({ value }) => value === selectedProvider) const name = provider?.label @@ -303,6 +343,7 @@ const ApiOptions = ({ uriScheme={uriScheme} fromWelcomeView={fromWelcomeView} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} @@ -313,6 +354,7 @@ const ApiOptions = ({ routerModels={routerModels} refetchRouterModels={refetchRouterModels} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} @@ -323,6 +365,7 @@ const ApiOptions = ({ routerModels={routerModels} uriScheme={uriScheme} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} @@ -332,6 +375,7 @@ const ApiOptions = ({ setApiConfigurationField={setApiConfigurationField} routerModels={routerModels} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} @@ -368,6 +412,7 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} @@ -404,6 +449,7 @@ const ApiOptions = ({ apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} /> )} diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 906b98e47e..bc962b921d 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -23,6 +23,7 @@ import { } from "@src/components/ui" import { ModelInfoView } from "./ModelInfoView" +import { ApiErrorMessage } from "./ApiErrorMessage" type ModelIdKey = keyof Pick< ProviderSettings, @@ -38,6 +39,7 @@ interface ModelPickerProps { apiConfiguration: ProviderSettings setApiConfigurationField: (field: K, value: ProviderSettings[K]) => void organizationAllowList: OrganizationAllowList + errorMessage?: string } export const ModelPicker = ({ @@ -49,6 +51,7 @@ export const ModelPicker = ({ apiConfiguration, setApiConfigurationField, organizationAllowList, + errorMessage, }: ModelPickerProps) => { const { t } = useAppTranslation() @@ -119,7 +122,8 @@ export const ModelPicker = ({ variant="combobox" role="combobox" aria-expanded={open} - className="w-full justify-between"> + className="w-full justify-between" + data-testid="model-picker-button">
{selectedModelId ?? t("settings:common.select")}
@@ -154,7 +158,11 @@ export const ModelPicker = ({ {modelIds.map((model) => ( - + {model}
+ {errorMessage && } {selectedModelId && selectedModelInfo && ( { await act(async () => { // Open the popover by clicking the button. - const button = screen.getByRole("combobox") + const button = screen.getByTestId("model-picker-button") fireEvent.click(button) }) @@ -91,7 +91,7 @@ describe("ModelPicker", () => { // Need to find and click the CommandItem to trigger onSelect await act(async () => { // Find the CommandItem for model2 and click it - const modelItem = screen.getByText("model2") + const modelItem = screen.getByTestId("model-option-model2") fireEvent.click(modelItem) }) @@ -104,7 +104,7 @@ describe("ModelPicker", () => { await act(async () => { // Open the popover by clicking the button. - const button = screen.getByRole("combobox") + const button = screen.getByTestId("model-picker-button") fireEvent.click(button) }) @@ -136,4 +136,111 @@ describe("ModelPicker", () => { // Verify the API config was updated with the custom model ID expect(mockSetApiConfigurationField).toHaveBeenCalledWith(defaultProps.modelIdKey, customModelId) }) + + describe("Error Message Display", () => { + it("displays error message when errorMessage prop is provided", async () => { + const errorMessage = "Model not available for your organization" + const propsWithError = { + ...defaultProps, + errorMessage, + } + + await act(async () => { + render( + + + , + ) + }) + + // Check that the error message is displayed + expect(screen.getByTestId("api-error-message")).toBeInTheDocument() + expect(screen.getByText(errorMessage)).toBeInTheDocument() + }) + + it("does not display error message when errorMessage prop is undefined", async () => { + await act(async () => renderModelPicker()) + + // Check that no error message is displayed + expect(screen.queryByTestId("api-error-message")).not.toBeInTheDocument() + }) + + it("displays error message below the model selector", async () => { + const errorMessage = "Invalid model selected" + const propsWithError = { + ...defaultProps, + errorMessage, + } + + await act(async () => { + render( + + + , + ) + }) + + // Check that both the model selector and error message are present + const modelSelector = screen.getByTestId("model-picker-button") + const errorContainer = screen.getByTestId("api-error-message") + const errorElement = screen.getByText(errorMessage) + + expect(modelSelector).toBeInTheDocument() + expect(errorContainer).toBeInTheDocument() + expect(errorElement).toBeInTheDocument() + expect(errorElement).toBeVisible() + }) + + it("updates error message when errorMessage prop changes", async () => { + const initialError = "Initial error" + const updatedError = "Updated error" + + const { rerender } = render( + + + , + ) + + // Check initial error is displayed + expect(screen.getByTestId("api-error-message")).toBeInTheDocument() + expect(screen.getByText(initialError)).toBeInTheDocument() + + // Update the error message + rerender( + + + , + ) + + // Check that the error message has been updated + expect(screen.getByTestId("api-error-message")).toBeInTheDocument() + expect(screen.queryByText(initialError)).not.toBeInTheDocument() + expect(screen.getByText(updatedError)).toBeInTheDocument() + }) + + it("removes error message when errorMessage prop becomes undefined", async () => { + const errorMessage = "Temporary error" + + const { rerender } = render( + + + , + ) + + // Check error is initially displayed + expect(screen.getByTestId("api-error-message")).toBeInTheDocument() + expect(screen.getByText(errorMessage)).toBeInTheDocument() + + // Remove the error message + rerender( + + + , + ) + + // Check that the error message has been removed + expect(screen.queryByTestId("api-error-message")).not.toBeInTheDocument() + expect(screen.queryByText(errorMessage)).not.toBeInTheDocument() + }) + }) }) diff --git a/webview-ui/src/components/settings/providers/Glama.tsx b/webview-ui/src/components/settings/providers/Glama.tsx index 85c218954a..ca1c6590ef 100644 --- a/webview-ui/src/components/settings/providers/Glama.tsx +++ b/webview-ui/src/components/settings/providers/Glama.tsx @@ -18,6 +18,7 @@ type GlamaProps = { routerModels?: RouterModels uriScheme?: string organizationAllowList: OrganizationAllowList + modelValidationError?: string } export const Glama = ({ @@ -26,6 +27,7 @@ export const Glama = ({ routerModels, uriScheme, organizationAllowList, + modelValidationError, }: GlamaProps) => { const { t } = useAppTranslation() @@ -67,6 +69,7 @@ export const Glama = ({ serviceName="Glama" serviceUrl="https://glama.ai/models" organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> ) diff --git a/webview-ui/src/components/settings/providers/LiteLLM.tsx b/webview-ui/src/components/settings/providers/LiteLLM.tsx index 6da99e9892..a2467b3c0b 100644 --- a/webview-ui/src/components/settings/providers/LiteLLM.tsx +++ b/webview-ui/src/components/settings/providers/LiteLLM.tsx @@ -18,9 +18,15 @@ type LiteLLMProps = { apiConfiguration: ProviderSettings setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void organizationAllowList: OrganizationAllowList + modelValidationError?: string } -export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, organizationAllowList }: LiteLLMProps) => { +export const LiteLLM = ({ + apiConfiguration, + setApiConfigurationField, + organizationAllowList, + modelValidationError, +}: LiteLLMProps) => { const { t } = useAppTranslation() const { routerModels } = useExtensionState() const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") @@ -143,6 +149,7 @@ export const LiteLLM = ({ apiConfiguration, setApiConfigurationField, organizati serviceUrl="https://docs.litellm.ai/" setApiConfigurationField={setApiConfigurationField} organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> ) diff --git a/webview-ui/src/components/settings/providers/OpenAICompatible.tsx b/webview-ui/src/components/settings/providers/OpenAICompatible.tsx index 43fea540c3..9c2f051465 100644 --- a/webview-ui/src/components/settings/providers/OpenAICompatible.tsx +++ b/webview-ui/src/components/settings/providers/OpenAICompatible.tsx @@ -27,12 +27,14 @@ type OpenAICompatibleProps = { apiConfiguration: ProviderSettings setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void organizationAllowList: OrganizationAllowList + modelValidationError?: string } export const OpenAICompatible = ({ apiConfiguration, setApiConfigurationField, organizationAllowList, + modelValidationError, }: OpenAICompatibleProps) => { const { t } = useAppTranslation() @@ -144,6 +146,7 @@ export const OpenAICompatible = ({ serviceName="OpenAI" serviceUrl="https://platform.openai.com" organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> { const { t } = useAppTranslation() @@ -135,6 +137,7 @@ export const OpenRouter = ({ serviceName="OpenRouter" serviceUrl="https://openrouter.ai/models" organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> {openRouterModelProviders && Object.keys(openRouterModelProviders).length > 0 && (
diff --git a/webview-ui/src/components/settings/providers/Requesty.tsx b/webview-ui/src/components/settings/providers/Requesty.tsx index 617e401211..ac9e2735e9 100644 --- a/webview-ui/src/components/settings/providers/Requesty.tsx +++ b/webview-ui/src/components/settings/providers/Requesty.tsx @@ -20,6 +20,7 @@ type RequestyProps = { routerModels?: RouterModels refetchRouterModels: () => void organizationAllowList: OrganizationAllowList + modelValidationError?: string } export const Requesty = ({ @@ -28,6 +29,7 @@ export const Requesty = ({ routerModels, refetchRouterModels, organizationAllowList, + modelValidationError, }: RequestyProps) => { const { t } = useAppTranslation() @@ -96,6 +98,7 @@ export const Requesty = ({ serviceName="Requesty" serviceUrl="https://requesty.ai" organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> ) diff --git a/webview-ui/src/components/settings/providers/Unbound.tsx b/webview-ui/src/components/settings/providers/Unbound.tsx index d0a862f20c..001ebf058e 100644 --- a/webview-ui/src/components/settings/providers/Unbound.tsx +++ b/webview-ui/src/components/settings/providers/Unbound.tsx @@ -19,6 +19,7 @@ type UnboundProps = { setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void routerModels?: RouterModels organizationAllowList: OrganizationAllowList + modelValidationError?: string } export const Unbound = ({ @@ -26,6 +27,7 @@ export const Unbound = ({ setApiConfigurationField, routerModels, organizationAllowList, + modelValidationError, }: UnboundProps) => { const { t } = useAppTranslation() const [didRefetch, setDidRefetch] = useState() @@ -176,6 +178,7 @@ export const Unbound = ({ serviceUrl="https://api.getunbound.ai/models" setApiConfigurationField={setApiConfigurationField} organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} /> ) diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.test.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.test.ts index e7806a9f21..b2d069201e 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.test.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.test.ts @@ -284,18 +284,8 @@ describe("useSelectedModel", () => { const wrapper = createWrapper() const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) - expect(result.current.id).toBe("anthropic/claude-sonnet-4") - expect(result.current.info).toEqual({ - maxTokens: 8192, - contextWindow: 200_000, - supportsImages: true, - supportsComputerUse: true, - supportsPromptCache: true, - inputPrice: 3.0, - outputPrice: 15.0, - cacheWritesPrice: 3.75, - cacheReadsPrice: 0.3, - }) + expect(result.current.id).toBe("non-existent-model") + expect(result.current.info).toBeUndefined() }) }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 9f77cbe370..72cee39e41 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -75,7 +75,10 @@ function getSelectedModel({ apiConfiguration: ProviderSettings routerModels: RouterModels openRouterModelProviders: Record -}): { id: string; info: ModelInfo } { +}): { id: string; info: ModelInfo | undefined } { + // the `undefined` case are used to show the invalid selection to prevent + // users from seeing the default model if their selection is invalid + // this gives a better UX than showing the default model switch (provider) { case "openrouter": { const id = apiConfiguration.openRouterModelId ?? openRouterDefaultModelId @@ -91,50 +94,42 @@ function getSelectedModel({ : openRouterModelProviders[specificProvider] } - return info - ? { id, info } - : { id: openRouterDefaultModelId, info: routerModels.openrouter[openRouterDefaultModelId] } + return { id, info } } case "requesty": { const id = apiConfiguration.requestyModelId ?? requestyDefaultModelId const info = routerModels.requesty[id] - return info - ? { id, info } - : { id: requestyDefaultModelId, info: routerModels.requesty[requestyDefaultModelId] } + return { id, info } } case "glama": { const id = apiConfiguration.glamaModelId ?? glamaDefaultModelId const info = routerModels.glama[id] - return info ? { id, info } : { id: glamaDefaultModelId, info: routerModels.glama[glamaDefaultModelId] } + return { id, info } } case "unbound": { const id = apiConfiguration.unboundModelId ?? unboundDefaultModelId const info = routerModels.unbound[id] - return info - ? { id, info } - : { id: unboundDefaultModelId, info: routerModels.unbound[unboundDefaultModelId] } + return { id, info } } case "litellm": { const id = apiConfiguration.litellmModelId ?? litellmDefaultModelId const info = routerModels.litellm[id] - return info - ? { id, info } - : { id: litellmDefaultModelId, info: routerModels.litellm[litellmDefaultModelId] } + return { id, info } } case "xai": { const id = apiConfiguration.apiModelId ?? xaiDefaultModelId const info = xaiModels[id as keyof typeof xaiModels] - return info ? { id, info } : { id: xaiDefaultModelId, info: xaiModels[xaiDefaultModelId] } + return info ? { id, info } : { id, info: undefined } } case "groq": { const id = apiConfiguration.apiModelId ?? groqDefaultModelId const info = groqModels[id as keyof typeof groqModels] - return info ? { id, info } : { id: groqDefaultModelId, info: groqModels[groqDefaultModelId] } + return { id, info } } case "chutes": { const id = apiConfiguration.apiModelId ?? chutesDefaultModelId const info = chutesModels[id as keyof typeof chutesModels] - return info ? { id, info } : { id: chutesDefaultModelId, info: chutesModels[chutesDefaultModelId] } + return { id, info } } case "bedrock": { const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId @@ -148,34 +143,32 @@ function getSelectedModel({ } } - return info ? { id, info } : { id: bedrockDefaultModelId, info: bedrockModels[bedrockDefaultModelId] } + return { id, info } } case "vertex": { const id = apiConfiguration.apiModelId ?? vertexDefaultModelId const info = vertexModels[id as keyof typeof vertexModels] - return info ? { id, info } : { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] } + return { id, info } } case "gemini": { const id = apiConfiguration.apiModelId ?? geminiDefaultModelId const info = geminiModels[id as keyof typeof geminiModels] - return info ? { id, info } : { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] } + return { id, info } } case "deepseek": { const id = apiConfiguration.apiModelId ?? deepSeekDefaultModelId const info = deepSeekModels[id as keyof typeof deepSeekModels] - return info ? { id, info } : { id: deepSeekDefaultModelId, info: deepSeekModels[deepSeekDefaultModelId] } + return { id, info } } case "openai-native": { const id = apiConfiguration.apiModelId ?? openAiNativeDefaultModelId const info = openAiNativeModels[id as keyof typeof openAiNativeModels] - return info - ? { id, info } - : { id: openAiNativeDefaultModelId, info: openAiNativeModels[openAiNativeDefaultModelId] } + return { id, info } } case "mistral": { const id = apiConfiguration.apiModelId ?? mistralDefaultModelId const info = mistralModels[id as keyof typeof mistralModels] - return info ? { id, info } : { id: mistralDefaultModelId, info: mistralModels[mistralDefaultModelId] } + return { id, info } } case "openai": { const id = apiConfiguration.openAiModelId ?? "" @@ -206,7 +199,7 @@ function getSelectedModel({ default: { const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId const info = anthropicModels[id as keyof typeof anthropicModels] - return info ? { id, info } : { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] } + return { id, info } } } } diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts new file mode 100644 index 0000000000..404b50e1dd --- /dev/null +++ b/webview-ui/src/utils/__tests__/validate.test.ts @@ -0,0 +1,187 @@ +import { ProviderSettings, OrganizationAllowList } from "@roo-code/types" +import { RouterModels } from "@roo/api" + +import { getModelValidationError, validateApiConfigurationExcludingModelErrors } from "../validate" + +describe("Model Validation Functions", () => { + const mockRouterModels: RouterModels = { + openrouter: { + "valid-model": { + maxTokens: 8192, + contextWindow: 200000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 3.0, + outputPrice: 15.0, + }, + "another-valid-model": { + maxTokens: 4096, + contextWindow: 100000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 1.0, + outputPrice: 5.0, + }, + }, + glama: { + "valid-model": { + maxTokens: 8192, + contextWindow: 200000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 3.0, + outputPrice: 15.0, + }, + }, + requesty: {}, + unbound: {}, + litellm: {}, + } + + const allowAllOrganization: OrganizationAllowList = { + allowAll: true, + providers: {}, + } + + const restrictiveOrganization: OrganizationAllowList = { + allowAll: false, + providers: { + openrouter: { + allowAll: false, + models: ["valid-model"], + }, + }, + } + + describe("getModelValidationError", () => { + it("returns undefined for valid OpenRouter model", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "valid-model", + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBeUndefined() + }) + + it("returns error for invalid OpenRouter model", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "invalid-model", + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBe("validation.modelAvailability") + }) + + it("returns error for model not allowed by organization", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "another-valid-model", + } + + const result = getModelValidationError(config, mockRouterModels, restrictiveOrganization) + expect(result).toContain("model") + }) + + it("returns undefined for valid Glama model", () => { + const config: ProviderSettings = { + apiProvider: "glama", + glamaModelId: "valid-model", + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBeUndefined() + }) + + it("returns error for invalid Glama model", () => { + const config: ProviderSettings = { + apiProvider: "glama", + glamaModelId: "invalid-model", + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBeUndefined() + }) + + it("returns undefined for OpenAI models when no router models provided", () => { + const config: ProviderSettings = { + apiProvider: "openai", + openAiModelId: "gpt-4", + } + + const result = getModelValidationError(config, undefined, allowAllOrganization) + expect(result).toBeUndefined() + }) + + it("handles empty model IDs gracefully", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "", + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBe("validation.modelId") + }) + + it("handles undefined model IDs gracefully", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + // openRouterModelId is undefined + } + + const result = getModelValidationError(config, mockRouterModels, allowAllOrganization) + expect(result).toBe("validation.modelId") + }) + }) + + describe("validateApiConfigurationExcludingModelErrors", () => { + it("returns undefined when configuration is valid", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterApiKey: "valid-key", + openRouterModelId: "valid-model", + } + + const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization) + expect(result).toBeUndefined() + }) + + it("returns error for missing API key", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterModelId: "valid-model", + // Missing openRouterApiKey + } + + const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization) + expect(result).toBe("validation.apiKey") + }) + + it("excludes model-specific errors", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterApiKey: "valid-key", + openRouterModelId: "invalid-model", // This should be ignored + } + + const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization) + expect(result).toBeUndefined() // Should not return model validation error + }) + + it("excludes model-specific organization errors", () => { + const config: ProviderSettings = { + apiProvider: "openrouter", + openRouterApiKey: "valid-key", + openRouterModelId: "another-valid-model", // Not allowed by restrictive org + } + + const result = validateApiConfigurationExcludingModelErrors( + config, + mockRouterModels, + restrictiveOrganization, + ) + expect(result).toBeUndefined() // Should exclude model-specific org errors + }) + }) +}) diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 5122ca58d4..2c1b21c256 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -14,12 +14,12 @@ export function validateApiConfiguration( return keysAndIdsPresentErrorMessage } - const organizationAllowListErrorMessage = validateProviderAgainstOrganizationSettings( + const organizationAllowListError = validateProviderAgainstOrganizationSettings( apiConfiguration, organizationAllowList, ) - if (organizationAllowListErrorMessage) { - return organizationAllowListErrorMessage + if (organizationAllowListError) { + return organizationAllowListError.message } return validateModelId(apiConfiguration, routerModels) @@ -107,17 +107,25 @@ function validateModelsAndKeysProvided(apiConfiguration: ProviderSettings): stri return undefined } +type ValidationError = { + message: string + code: 'PROVIDER_NOT_ALLOWED' | 'MODEL_NOT_ALLOWED' +} + function validateProviderAgainstOrganizationSettings( apiConfiguration: ProviderSettings, organizationAllowList?: OrganizationAllowList, -): string | undefined { +): ValidationError | undefined { if (organizationAllowList && !organizationAllowList.allowAll) { const provider = apiConfiguration.apiProvider if (!provider) return undefined const providerConfig = organizationAllowList.providers[provider] if (!providerConfig) { - return i18next.t("settings:validation.providerNotAllowed", { provider }) + return { + message: i18next.t("settings:validation.providerNotAllowed", { provider }), + code: 'PROVIDER_NOT_ALLOWED' + } } if (!providerConfig.allowAll) { @@ -125,10 +133,13 @@ function validateProviderAgainstOrganizationSettings( const allowedModels = providerConfig.models || [] if (modelId && !allowedModels.includes(modelId)) { - return i18next.t("settings:validation.modelNotAllowed", { - model: modelId, - provider, - }) + return { + message: i18next.t("settings:validation.modelNotAllowed", { + model: modelId, + provider, + }), + code: 'MODEL_NOT_ALLOWED' + } } } } @@ -233,3 +244,55 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels return undefined } + +/** + * Extracts model-specific validation errors from the API configuration + * This is used to show model errors specifically in the model selector components + */ +export function getModelValidationError( + apiConfiguration: ProviderSettings, + routerModels?: RouterModels, + organizationAllowList?: OrganizationAllowList, +): string | undefined { + const modelId = getModelIdForProvider(apiConfiguration, apiConfiguration.apiProvider || "") + const configWithModelId = { + ...apiConfiguration, + apiModelId: modelId || "", + } + + const orgError = validateProviderAgainstOrganizationSettings(configWithModelId, organizationAllowList) + if (orgError && orgError.code === 'MODEL_NOT_ALLOWED') { + return orgError.message + } + + return validateModelId(configWithModelId, routerModels) +} + +/** + * Validates API configuration but excludes model-specific errors + * This is used for the general API error display to prevent duplication + * when model errors are shown in the model selector + */ +export function validateApiConfigurationExcludingModelErrors( + apiConfiguration: ProviderSettings, + _routerModels?: RouterModels, // keeping this for compatibility with the old function + organizationAllowList?: OrganizationAllowList, +): string | undefined { + const keysAndIdsPresentErrorMessage = validateModelsAndKeysProvided(apiConfiguration) + if (keysAndIdsPresentErrorMessage) { + return keysAndIdsPresentErrorMessage + } + + const organizationAllowListError = validateProviderAgainstOrganizationSettings( + apiConfiguration, + organizationAllowList, + ) + + // only return organization errors if they're not model-specific + if (organizationAllowListError && organizationAllowListError.code === 'PROVIDER_NOT_ALLOWED') { + return organizationAllowListError.message + } + + // skip model validation errors as they'll be shown in the model selector + return undefined +}