Skip to content

Commit dd78e39

Browse files
committed
Improves AI provider/model fallback handling
1 parent cb17782 commit dd78e39

File tree

3 files changed

+81
-43
lines changed

3 files changed

+81
-43
lines changed

src/plus/ai/aiProviderService.ts

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -327,57 +327,92 @@ export class AIProviderService implements Disposable {
327327
return modelResults.flatMap(m => getSettledValue(m, []));
328328
}
329329

330+
private async getBestFallbackModel(): Promise<AIModel | undefined> {
331+
let model: AIModel | undefined;
332+
let models: readonly AIModel[];
333+
334+
const orgAIConfig = getOrgAIConfig();
335+
// First, use Copilot GPT 4.1 or first model
336+
if (isProviderEnabledByOrg('vscode', orgAIConfig)) {
337+
try {
338+
models = await this.getModels('vscode');
339+
if (models.length) {
340+
model = models.find(m => m.id === 'copilot:gpt-4.1') ?? models[0];
341+
if (model != null) return model;
342+
}
343+
} catch {}
344+
}
345+
346+
// Second, use the GitKraken AI default or first model
347+
if (isProviderEnabledByOrg('gitkraken', orgAIConfig)) {
348+
try {
349+
const subscription = await this.container.subscription.getSubscription();
350+
if (subscription.account?.verified) {
351+
models = await this.getModels('gitkraken');
352+
353+
model = models.find(m => m.default) ?? models[0];
354+
if (model != null) return model;
355+
}
356+
} catch {}
357+
}
358+
359+
return model;
360+
}
361+
330362
async getModel(options?: { force?: boolean; silent?: boolean }, source?: Source): Promise<AIModel | undefined> {
331363
const cfg = this.getConfiguredModel();
332364
if (!options?.force && cfg?.provider != null && cfg?.model != null) {
333365
const model = await this.getOrUpdateModel(cfg.provider, cfg.model);
334366
if (model != null) return model;
335367
}
336368

337-
if (options?.silent) return undefined;
338-
339-
let chosenProviderId: AIProviders | undefined;
340369
let chosenModel: AIModel | undefined;
341-
const orgAiConf = getOrgAIConfig();
342-
343-
if (!options?.force) {
344-
const vsCodeModels = await this.getModels('vscode');
345-
if (isProviderEnabledByOrg('vscode', orgAiConf) && vsCodeModels.length !== 0) {
346-
chosenProviderId = 'vscode';
347-
} else if (
348-
isProviderEnabledByOrg('gitkraken', orgAiConf) &&
349-
(await this.container.subscription.getSubscription()).account?.verified
350-
) {
351-
chosenProviderId = 'gitkraken';
352-
const gitkrakenModels = await this.getModels('gitkraken');
353-
chosenModel = gitkrakenModels.find(m => m.default);
370+
let chosenProviderId: AIProviders | undefined;
371+
const fallbackModel = lazy(() => this.getBestFallbackModel());
372+
373+
if (!options?.silent) {
374+
if (!options?.force) {
375+
chosenModel = await fallbackModel.value;
376+
chosenProviderId = chosenModel?.provider.id;
354377
}
355-
}
356378

357-
while (true) {
358-
chosenProviderId ??= (await showAIProviderPicker(this.container, cfg))?.provider;
359-
if (chosenProviderId == null) return;
379+
while (true) {
380+
chosenProviderId ??= (await showAIProviderPicker(this.container, cfg))?.provider;
381+
if (chosenProviderId == null) {
382+
chosenModel = undefined;
383+
break;
384+
}
385+
386+
const provider = supportedAIProviders.get(chosenProviderId);
387+
if (provider == null) {
388+
chosenModel = undefined;
389+
break;
390+
}
360391

361-
const provider = supportedAIProviders.get(chosenProviderId);
362-
if (provider == null) return;
392+
if (!(await this.ensureProviderConfigured(provider, false))) {
393+
chosenModel = undefined;
394+
}
363395

364-
if (!(await this.ensureProviderConfigured(provider, false))) return;
396+
if (chosenModel == null) {
397+
const result = await showAIModelPicker(this.container, chosenProviderId, cfg);
398+
if (result == null || (isDirective(result) && result !== Directive.Back)) {
399+
chosenModel = undefined;
400+
break;
401+
}
402+
if (result === Directive.Back) {
403+
chosenProviderId = undefined;
404+
continue;
405+
}
365406

366-
if (chosenModel == null) {
367-
const result = await showAIModelPicker(this.container, chosenProviderId, cfg);
368-
if (result == null || (isDirective(result) && result !== Directive.Back)) return;
369-
if (result === Directive.Back) {
370-
chosenProviderId = undefined;
371-
continue;
407+
chosenModel = result.model;
372408
}
373409

374-
chosenModel = result.model;
410+
break;
375411
}
376-
377-
break;
378412
}
379413

380-
const model = await this.getOrUpdateModel(chosenModel);
414+
chosenModel ??= await fallbackModel.value;
415+
const model = chosenModel == null ? undefined : await this.getOrUpdateModel(chosenModel);
381416

382417
this.container.telemetry.sendEvent(
383418
'ai/switchModel',
@@ -391,7 +426,9 @@ export class AIProviderService implements Disposable {
391426
source,
392427
);
393428

394-
void (await showConfirmAIProviderToS(this.container.storage));
429+
if (model != null) {
430+
void (await showConfirmAIProviderToS(this.container.storage));
431+
}
395432
return model;
396433
}
397434

src/plus/ai/utils/-webview/ai.utils.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,15 @@ export function getOrgAIConfig(): OrgAIConfig {
189189
};
190190
}
191191

192-
export function getOrgAIProviderOfType(type: AIProviders, orgAiConfig?: OrgAIConfig): OrgAIProvider {
193-
orgAiConfig ??= getOrgAIConfig();
194-
if (!orgAiConfig.aiEnabled) return { type: type, enabled: false };
195-
if (!orgAiConfig.enforceAiProviders) return { type: type, enabled: true };
196-
return orgAiConfig.aiProviders[type] ?? { type: type, enabled: false };
192+
export function getOrgAIProviderOfType(type: AIProviders, orgAIConfig?: OrgAIConfig): OrgAIProvider {
193+
orgAIConfig ??= getOrgAIConfig();
194+
if (!orgAIConfig.aiEnabled) return { type: type, enabled: false };
195+
if (!orgAIConfig.enforceAiProviders) return { type: type, enabled: true };
196+
return orgAIConfig.aiProviders[type] ?? { type: type, enabled: false };
197197
}
198198

199-
export function isProviderEnabledByOrg(type: AIProviders, orgAiConfig?: OrgAIConfig): boolean {
200-
return getOrgAIProviderOfType(type, orgAiConfig).enabled;
199+
export function isProviderEnabledByOrg(type: AIProviders, orgAIConfig?: OrgAIConfig): boolean {
200+
return getOrgAIProviderOfType(type, orgAIConfig).enabled;
201201
}
202202

203203
/**

src/quickpicks/aiModelPicker.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ export async function showAIModelPicker(
137137

138138
const items: Array<ModelQuickPickItem | DirectiveQuickPickItem> = [];
139139

140-
if (models.length === 0 && provider === 'ollama') {
140+
if (!models.length) {
141141
items.push({
142142
label: 'No models found',
143-
description: 'Please install a model or check your Ollama server configuration',
143+
description:
144+
provider === 'ollama' ? 'Please install a model or check your Ollama server configuration' : undefined,
144145
iconPath: new ThemeIcon('error'),
145146
directive: Directive.Noop,
146147
} satisfies ModelQuickPickItem | DirectiveQuickPickItem);

0 commit comments

Comments
 (0)