@@ -11,10 +11,20 @@ import {
1111 glamaDefaultModelId ,
1212 unboundDefaultModelId ,
1313 litellmDefaultModelId ,
14+ openAiNativeDefaultModelId ,
15+ anthropicDefaultModelId ,
16+ geminiDefaultModelId ,
17+ deepSeekDefaultModelId ,
18+ mistralDefaultModelId ,
19+ xaiDefaultModelId ,
20+ groqDefaultModelId ,
21+ chutesDefaultModelId ,
22+ bedrockDefaultModelId ,
23+ vertexDefaultModelId ,
1424} from "@roo-code/types"
1525
1626import { vscode } from "@src/utils/vscode"
17- import { validateApiConfiguration } from "@src/utils/validate"
27+ import { validateApiConfigurationExcludingModelErrors , getModelValidationError } from "@src/utils/validate"
1828import { useAppTranslation } from "@src/i18n/TranslationContext"
1929import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
2030import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel"
@@ -176,8 +186,11 @@ const ApiOptions = ({
176186 )
177187
178188 useEffect ( ( ) => {
179- const apiValidationResult = validateApiConfiguration ( apiConfiguration , routerModels , organizationAllowList )
180-
189+ const apiValidationResult = validateApiConfigurationExcludingModelErrors (
190+ apiConfiguration ,
191+ routerModels ,
192+ organizationAllowList ,
193+ )
181194 setErrorMessage ( apiValidationResult )
182195 } , [ apiConfiguration , routerModels , organizationAllowList , setErrorMessage ] )
183196
@@ -187,63 +200,90 @@ const ApiOptions = ({
187200
188201 const filteredModels = filterModels ( models , selectedProvider , organizationAllowList )
189202
190- return filteredModels
203+ const modelOptions = filteredModels
191204 ? Object . keys ( filteredModels ) . map ( ( modelId ) => ( {
192205 value : modelId ,
193206 label : modelId ,
194207 } ) )
195208 : [ ]
209+
210+ return modelOptions
196211 } , [ selectedProvider , organizationAllowList ] )
197212
198213 const onProviderChange = useCallback (
199214 ( value : ProviderName ) => {
215+ setApiConfigurationField ( "apiProvider" , value )
216+
200217 // It would be much easier to have a single attribute that stores
201218 // the modelId, but we have a separate attribute for each of
202219 // OpenRouter, Glama, Unbound, and Requesty.
203220 // If you switch to one of these providers and the corresponding
204221 // modelId is not set then you immediately end up in an error state.
205222 // To address that we set the modelId to the default value for th
206223 // provider if it's not already set.
207- switch ( value ) {
208- case "openrouter" :
209- if ( ! apiConfiguration . openRouterModelId ) {
210- setApiConfigurationField ( "openRouterModelId" , openRouterDefaultModelId )
211- }
212- break
213- case "glama" :
214- if ( ! apiConfiguration . glamaModelId ) {
215- setApiConfigurationField ( "glamaModelId" , glamaDefaultModelId )
216- }
217- break
218- case "unbound" :
219- if ( ! apiConfiguration . unboundModelId ) {
220- setApiConfigurationField ( "unboundModelId" , unboundDefaultModelId )
221- }
222- break
223- case "requesty" :
224- if ( ! apiConfiguration . requestyModelId ) {
225- setApiConfigurationField ( "requestyModelId" , requestyDefaultModelId )
226- }
227- break
228- case "litellm" :
229- if ( ! apiConfiguration . litellmModelId ) {
230- setApiConfigurationField ( "litellmModelId" , litellmDefaultModelId )
224+ const validateAndResetModel = (
225+ modelId : string | undefined ,
226+ field : keyof ProviderSettings ,
227+ defaultValue ?: string ,
228+ ) => {
229+ // in case we haven't set a default value for a provider
230+ if ( ! defaultValue ) return
231+
232+ // only set default if no model is set, but don't reset invalid models
233+ // let users see and decide what to do with invalid model selections
234+ const shouldSetDefault = ! modelId
235+
236+ if ( shouldSetDefault ) {
237+ setApiConfigurationField ( field , defaultValue )
238+ }
239+ }
240+
241+ // Define a mapping object that associates each provider with its model configuration
242+ const PROVIDER_MODEL_CONFIG : Partial <
243+ Record <
244+ ProviderName ,
245+ {
246+ field : keyof ProviderSettings
247+ default ?: string
231248 }
232- break
249+ >
250+ > = {
251+ openrouter : { field : "openRouterModelId" , default : openRouterDefaultModelId } ,
252+ glama : { field : "glamaModelId" , default : glamaDefaultModelId } ,
253+ unbound : { field : "unboundModelId" , default : unboundDefaultModelId } ,
254+ requesty : { field : "requestyModelId" , default : requestyDefaultModelId } ,
255+ litellm : { field : "litellmModelId" , default : litellmDefaultModelId } ,
256+ anthropic : { field : "apiModelId" , default : anthropicDefaultModelId } ,
257+ "openai-native" : { field : "apiModelId" , default : openAiNativeDefaultModelId } ,
258+ gemini : { field : "apiModelId" , default : geminiDefaultModelId } ,
259+ deepseek : { field : "apiModelId" , default : deepSeekDefaultModelId } ,
260+ mistral : { field : "apiModelId" , default : mistralDefaultModelId } ,
261+ xai : { field : "apiModelId" , default : xaiDefaultModelId } ,
262+ groq : { field : "apiModelId" , default : groqDefaultModelId } ,
263+ chutes : { field : "apiModelId" , default : chutesDefaultModelId } ,
264+ bedrock : { field : "apiModelId" , default : bedrockDefaultModelId } ,
265+ vertex : { field : "apiModelId" , default : vertexDefaultModelId } ,
266+ openai : { field : "openAiModelId" } ,
267+ ollama : { field : "ollamaModelId" } ,
268+ lmstudio : { field : "lmStudioModelId" } ,
233269 }
234270
235- setApiConfigurationField ( "apiProvider" , value )
271+ const config = PROVIDER_MODEL_CONFIG [ value ]
272+ if ( config ) {
273+ validateAndResetModel (
274+ apiConfiguration [ config . field ] as string | undefined ,
275+ config . field ,
276+ config . default ,
277+ )
278+ }
236279 } ,
237- [
238- setApiConfigurationField ,
239- apiConfiguration . openRouterModelId ,
240- apiConfiguration . glamaModelId ,
241- apiConfiguration . unboundModelId ,
242- apiConfiguration . requestyModelId ,
243- apiConfiguration . litellmModelId ,
244- ] ,
280+ [ setApiConfigurationField , apiConfiguration ] ,
245281 )
246282
283+ const modelValidationError = useMemo ( ( ) => {
284+ return getModelValidationError ( apiConfiguration , routerModels , organizationAllowList )
285+ } , [ apiConfiguration , routerModels , organizationAllowList ] )
286+
247287 const docs = useMemo ( ( ) => {
248288 const provider = PROVIDERS . find ( ( { value } ) => value === selectedProvider )
249289 const name = provider ?. label
@@ -303,6 +343,7 @@ const ApiOptions = ({
303343 uriScheme = { uriScheme }
304344 fromWelcomeView = { fromWelcomeView }
305345 organizationAllowList = { organizationAllowList }
346+ modelValidationError = { modelValidationError }
306347 />
307348 ) }
308349
@@ -313,6 +354,7 @@ const ApiOptions = ({
313354 routerModels = { routerModels }
314355 refetchRouterModels = { refetchRouterModels }
315356 organizationAllowList = { organizationAllowList }
357+ modelValidationError = { modelValidationError }
316358 />
317359 ) }
318360
@@ -323,6 +365,7 @@ const ApiOptions = ({
323365 routerModels = { routerModels }
324366 uriScheme = { uriScheme }
325367 organizationAllowList = { organizationAllowList }
368+ modelValidationError = { modelValidationError }
326369 />
327370 ) }
328371
@@ -332,6 +375,7 @@ const ApiOptions = ({
332375 setApiConfigurationField = { setApiConfigurationField }
333376 routerModels = { routerModels }
334377 organizationAllowList = { organizationAllowList }
378+ modelValidationError = { modelValidationError }
335379 />
336380 ) }
337381
@@ -368,6 +412,7 @@ const ApiOptions = ({
368412 apiConfiguration = { apiConfiguration }
369413 setApiConfigurationField = { setApiConfigurationField }
370414 organizationAllowList = { organizationAllowList }
415+ modelValidationError = { modelValidationError }
371416 />
372417 ) }
373418
@@ -404,6 +449,7 @@ const ApiOptions = ({
404449 apiConfiguration = { apiConfiguration }
405450 setApiConfigurationField = { setApiConfigurationField }
406451 organizationAllowList = { organizationAllowList }
452+ modelValidationError = { modelValidationError }
407453 />
408454 ) }
409455
0 commit comments