Skip to content

Commit c823663

Browse files
Updates AI provider/model selection flow
1 parent c283896 commit c823663

File tree

6 files changed

+296
-30
lines changed

6 files changed

+296
-30
lines changed

src/plus/ai/aiProviderService.ts

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { CancellationToken, Disposable, Event, MessageItem, ProgressOptions } from 'vscode';
2-
import { env, EventEmitter, window } from 'vscode';
2+
import { env, EventEmitter, ThemeIcon, window } from 'vscode';
33
import type { AIPrimaryProviders, AIProviderAndModel, AIProviders, SupportedAIModels } from '../../constants.ai';
44
import { primaryAIProviders } from '../../constants.ai';
55
import type { AIGenerateDraftEventData, Source, TelemetryEvents } from '../../constants.telemetry';
@@ -11,7 +11,7 @@ import type { GitRevisionReference } from '../../git/models/reference';
1111
import type { Repository } from '../../git/models/repository';
1212
import { uncommitted, uncommittedStaged } from '../../git/models/revision';
1313
import { assertsCommitHasFullDetails } from '../../git/utils/commit.utils';
14-
import { showAIModelPicker } from '../../quickpicks/aiModelPicker';
14+
import { showAIModelPicker, showAIProviderPicker } from '../../quickpicks/aiPicker';
1515
import { configuration } from '../../system/-webview/configuration';
1616
import type { Storage } from '../../system/-webview/storage';
1717
import { supportedInVSCodeVersion } from '../../system/-webview/vscode';
@@ -22,6 +22,7 @@ import { lazy } from '../../system/lazy';
2222
import type { Deferred } from '../../system/promise';
2323
import { getSettledValue } from '../../system/promise';
2424
import type { ServerConnection } from '../gk/serverConnection';
25+
import { ensureAccountQuickPick } from '../gk/utils/-webview/acount.utils';
2526
import type { AIActionType, AIModel, AIModelDescriptor } from './models/model';
2627
import type { PromptTemplateContext } from './models/promptTemplates';
2728
import type { AIProvider, AIRequestResult } from './models/provider';
@@ -184,10 +185,48 @@ export class AIProviderService implements Disposable {
184185

185186
if (options?.silent) return undefined;
186187

187-
const pick = await showAIModelPicker(this.container, cfg);
188-
if (pick == null) return undefined;
188+
let chosenProvider: AIProviders | undefined = undefined;
189+
let chosenModel: AIModel | undefined = undefined;
190+
191+
if (!options?.force) {
192+
const vsCodeModels = await this.getModels('vscode');
193+
if (vsCodeModels.length !== 0) {
194+
chosenProvider = 'vscode';
195+
} else if ((await this.container.subscription.getSubscription()).account?.verified) {
196+
chosenProvider = 'gitkraken';
197+
const gitkrakenModels = await this.getModels('gitkraken');
198+
chosenModel = gitkrakenModels.find(m => m.default);
199+
}
200+
}
201+
202+
if (chosenProvider == null) {
203+
chosenProvider = (await showAIProviderPicker(this.container, cfg))?.provider;
204+
if (chosenProvider == null) return;
205+
if (
206+
(chosenProvider === 'gitkraken' ||
207+
(chosenProvider !== 'vscode' &&
208+
(await this.container.storage.getSecret(`gitlens.${chosenProvider}.key`)) == null)) &&
209+
!(await ensureAccountQuickPick(
210+
this.container,
211+
{
212+
label: 'Use AI-powered GitLens features like Generate Commit Message, Explain Commit, and more.',
213+
iconPath: new ThemeIcon('sparkle'),
214+
},
215+
source,
216+
))
217+
) {
218+
return;
219+
}
220+
}
189221

190-
const model = await this.getOrUpdateModel(pick.model);
222+
if (!(await this.ensureProviderConfigured(chosenProvider))) return;
223+
224+
if (chosenModel == null) {
225+
chosenModel = (await showAIModelPicker(this.container, chosenProvider, cfg))?.model;
226+
if (chosenModel == null) return;
227+
}
228+
229+
const model = await this.getOrUpdateModel(chosenModel);
191230

192231
this.container.telemetry.sendEvent(
193232
'ai/switchModel',
@@ -204,6 +243,24 @@ export class AIProviderService implements Disposable {
204243
return model;
205244
}
206245

246+
private async ensureProviderConfigured(providerId: AIProviders): Promise<boolean> {
247+
const key = await this.container.storage.getSecret(`gitlens.${providerId}.key`);
248+
if (key != null) return true;
249+
250+
if (this._provider != null && providerId === this._provider.id) return this._provider.ensureConfigured();
251+
const type = await _supportedProviderTypes.get(providerId)?.value;
252+
if (type == null) {
253+
return false;
254+
}
255+
256+
const p = new type(this.container, this.connection);
257+
try {
258+
return await p.ensureConfigured();
259+
} finally {
260+
p.dispose();
261+
}
262+
}
263+
207264
private getOrUpdateModel(model: AIModel): Promise<AIModel | undefined>;
208265
private getOrUpdateModel<T extends AIProviders>(providerId: T, modelId: string): Promise<AIModel | undefined>;
209266
private async getOrUpdateModel(
@@ -592,6 +649,14 @@ export class AIProviderService implements Disposable {
592649
return changes;
593650
}
594651

652+
async resetProvider(provider: AIProviders): Promise<void> {
653+
void env.clipboard.writeText((await this.container.storage.getSecret(`gitlens.${provider}.key`)) ?? '');
654+
void this.container.storage.deleteSecret(`gitlens.${provider}.key`);
655+
656+
void this.container.storage.delete(`confirm:ai:tos:${provider}`);
657+
void this.container.storage.deleteWorkspace(`confirm:ai:tos:${provider}`);
658+
}
659+
595660
async reset(all?: boolean): Promise<void> {
596661
let { _provider: provider } = this;
597662
if (provider == null) {
@@ -625,11 +690,7 @@ export class AIProviderService implements Disposable {
625690
}
626691

627692
if (provider != null && result === resetCurrent) {
628-
void env.clipboard.writeText((await this.container.storage.getSecret(`gitlens.${provider.id}.key`)) ?? '');
629-
void this.container.storage.deleteSecret(`gitlens.${provider.id}.key`);
630-
631-
void this.container.storage.delete(`confirm:ai:tos:${provider.id}`);
632-
void this.container.storage.deleteWorkspace(`confirm:ai:tos:${provider.id}`);
693+
void this.resetProvider(provider.id);
633694
} else if (result === resetAll) {
634695
const keys = [];
635696
for (const [providerId] of _supportedProviderTypes) {

src/plus/ai/models/provider.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export interface AIProvider<Provider extends AIProviders = AIProviders> extends
2727

2828
onDidChange?: Event<void>;
2929

30+
ensureConfigured(): Promise<boolean>;
3031
getModels(): Promise<readonly AIModel<Provider>[]>;
3132
getPromptTemplate(action: AIActionType, model: AIModel<Provider>): Promise<PromptTemplate | undefined>;
3233

src/plus/ai/openAICompatibleProvider.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ export abstract class OpenAICompatibleProvider<T extends AIProviders> implements
5959
});
6060
}
6161

62+
async ensureConfigured(): Promise<boolean> {
63+
return (await this.getApiKey()) != null;
64+
}
65+
6266
protected getHeaders<TAction extends AIActionType>(
6367
_action: TAction,
6468
_model: AIModel<T>,

src/plus/ai/vscodeProvider.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ export class VSCodeAIProvider implements AIProvider<typeof provider.id> {
5454
return models.map(getModelFromChatModel);
5555
}
5656

57+
async ensureConfigured(): Promise<boolean> {
58+
return (await this.getModels()).length !== 0;
59+
}
60+
5761
async getPromptTemplate(action: AIActionType, model: VSCodeAIModel): Promise<PromptTemplate | undefined> {
5862
return Promise.resolve(getLocalPromptTemplate(action, model));
5963
}

src/plus/gk/utils/-webview/acount.utils.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import type { QuickPickItem } from 'vscode';
12
import { window } from 'vscode';
23
import type { Source } from '../../../../constants.telemetry';
34
import type { Container } from '../../../../container';
5+
import { createQuickPickSeparator } from '../../../../quickpicks/items/common';
6+
import type { DirectiveQuickPickItem } from '../../../../quickpicks/items/directive';
7+
import { createDirectiveQuickPickItem, Directive } from '../../../../quickpicks/items/directive';
48

59
export async function ensureAccount(container: Container, title: string, source: Source): Promise<boolean> {
610
while (true) {
@@ -52,3 +56,54 @@ export async function ensureAccount(container: Container, title: string, source:
5256

5357
return true;
5458
}
59+
60+
export async function ensureAccountQuickPick(
61+
container: Container,
62+
descriptionItem: QuickPickItem,
63+
source?: Source,
64+
): Promise<boolean> {
65+
while (true) {
66+
const account = (await container.subscription.getSubscription()).account;
67+
if (account?.verified === true) break;
68+
const directives: DirectiveQuickPickItem[] = [
69+
createDirectiveQuickPickItem(Directive.Noop, undefined, descriptionItem),
70+
];
71+
let placeholder = 'Requires an account to continue';
72+
if (account?.verified === false) {
73+
directives.push(
74+
createDirectiveQuickPickItem(Directive.RequiresVerification, true),
75+
createQuickPickSeparator(),
76+
createDirectiveQuickPickItem(Directive.Cancel),
77+
);
78+
placeholder = 'You must verify your email before you can continue';
79+
} else {
80+
directives.push(
81+
createDirectiveQuickPickItem(Directive.SignIn, true),
82+
createQuickPickSeparator(),
83+
createDirectiveQuickPickItem(Directive.Cancel),
84+
);
85+
}
86+
87+
const result = await window.showQuickPick(directives, {
88+
placeHolder: placeholder,
89+
ignoreFocusOut: true,
90+
});
91+
92+
if (result == null) return false;
93+
if (result.directive === Directive.Noop) continue;
94+
if (result.directive === Directive.SignIn) {
95+
if (await container.subscription.loginOrSignUp(false, source)) {
96+
continue;
97+
}
98+
}
99+
if (result.directive === Directive.RequiresVerification) {
100+
if (await container.subscription.resendVerification(source)) {
101+
continue;
102+
}
103+
}
104+
105+
return false;
106+
}
107+
108+
return true;
109+
}

0 commit comments

Comments
 (0)