diff --git a/contributions.json b/contributions.json index a9c01185f9ec6..c4baed1c8c102 100644 --- a/contributions.json +++ b/contributions.json @@ -4570,7 +4570,7 @@ } }, "gitlens.switchAIModel": { - "label": "Switch AI Model", + "label": "Switch AI Provider/Model", "commandPalette": "gitlens:enabled && gitlens:gk:organization:ai:enabled" }, "gitlens.switchMode": { diff --git a/docs/telemetry-events.md b/docs/telemetry-events.md index 9a7b3b7e8b5ce..db68909d4939a 100644 --- a/docs/telemetry-events.md +++ b/docs/telemetry-events.md @@ -1484,7 +1484,7 @@ void 'repoPrivacy': 'private' | 'public' | 'local', 'repository.visibility': 'private' | 'public' | 'local', // Provided for compatibility with other GK surfaces - 'source': 'account' | 'subscription' | 'graph' | 'patchDetails' | 'settings' | 'timeline' | 'home' | 'view' | 'code-suggest' | 'associateIssueWithBranch' | 'cloud-patches' | 'commandPalette' | 'deeplink' | 'inspect' | 'inspect-overview' | 'integrations' | 'launchpad' | 'launchpad-indicator' | 'launchpad-view' | 'notification' | 'prompt' | 'quick-wizard' | 'remoteProvider' | 'startWork' | 'trial-indicator' | 'scm-input' | 'walkthrough' | 'whatsnew' | 'worktrees' + 'source': 'account' | 'subscription' | 'graph' | 'patchDetails' | 'settings' | 'timeline' | 'home' | 'view' | 'code-suggest' | 'ai' | 'ai:picker' | 'associateIssueWithBranch' | 'cloud-patches' | 'commandPalette' | 'deeplink' | 'inspect' | 'inspect-overview' | 'integrations' | 'launchpad' | 'launchpad-indicator' | 'launchpad-view' | 'notification' | 'prompt' | 'quick-wizard' | 'remoteProvider' | 'startWork' | 'trial-indicator' | 'scm-input' | 'walkthrough' | 'whatsnew' | 'worktrees' } ``` diff --git a/package.json b/package.json index 029519dae8353..e62d903c9173d 100644 --- a/package.json +++ b/package.json @@ -7705,7 +7705,7 @@ }, { "command": "gitlens.switchAIModel", - "title": "Switch AI Model", + "title": "Switch AI Provider/Model", "category": "GitLens" }, { diff --git a/src/commands/resets.ts b/src/commands/resets.ts index bedb95f8edc35..aefc1f8a6cfe4 100644 --- a/src/commands/resets.ts +++ b/src/commands/resets.ts @@ -10,6 +10,7 @@ import { GlCommandBase } from './commandBase'; const resetTypes = [ 'ai', + 'ai:confirmations', 'avatars', 'integrations', 'previews', @@ -35,6 +36,11 @@ export class ResetCommand extends GlCommandBase { detail: 'Clears any locally stored AI keys', item: 'ai', }, + { + label: 'AI Confirmations...', + detail: 'Clears any accepted AI confirmations', + item: 'ai:confirmations', + }, { label: 'Avatars...', detail: 'Clears the stored avatar cache', @@ -111,6 +117,10 @@ export class ResetCommand extends GlCommandBase { confirmationMessage = 'Are you sure you want to reset all of the stored AI keys?'; confirm.title = 'Reset AI Keys'; break; + case 'ai:confirmations': + confirmationMessage = 'Are you sure you want to reset all AI confirmations?'; + confirm.title = 'Reset AI Confirmations'; + break; case 'avatars': confirmationMessage = 'Are you sure you want to reset the avatar cache?'; confirm.title = 'Reset Avatars'; @@ -172,6 +182,10 @@ export class ResetCommand extends GlCommandBase { await this.container.ai.reset(true); break; + case 'ai:confirmations': + this.container.ai.resetConfirmations(); + break; + case 'avatars': resetAvatarCache('all'); break; diff --git a/src/constants.ai.ts b/src/constants.ai.ts index da44313b6eabb..f971f6470474a 100644 --- a/src/constants.ai.ts +++ b/src/constants.ai.ts @@ -1,3 +1,5 @@ +import type { AIProviderDescriptor } from './plus/ai/models/model'; + export type AIProviders = | 'anthropic' | 'deepseek' @@ -9,7 +11,70 @@ export type AIProviders = | 'vscode' | 'xai'; export type AIPrimaryProviders = Extract; -export const primaryAIProviders = ['gitkraken', 'vscode'] as const satisfies readonly AIPrimaryProviders[]; export type AIProviderAndModel = `${string}:${string}`; export type SupportedAIModels = `${Exclude}:${string}` | AIPrimaryProviders; + +export const gitKrakenProviderDescriptor: AIProviderDescriptor<'gitkraken'> = { + id: 'gitkraken', + name: 'GitKraken AI (Preview)', + primary: true, + requiresAccount: true, + requiresUserKey: false, +} as const; +export const vscodeProviderDescriptor: AIProviderDescriptor<'vscode'> = { + id: 'vscode', + name: 'Copilot', + primary: true, + requiresAccount: false, + requiresUserKey: false, +} as const; +export const openAIProviderDescriptor: AIProviderDescriptor<'openai'> = { + id: 'openai', + name: 'OpenAI', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const anthropicProviderDescriptor: AIProviderDescriptor<'anthropic'> = { + id: 'anthropic', + name: 'Anthropic', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const geminiProviderDescriptor: AIProviderDescriptor<'gemini'> = { + id: 'gemini', + name: 'Google', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const deepSeekProviderDescriptor: AIProviderDescriptor<'deepseek'> = { + id: 'deepseek', + name: 'DeepSeek', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const xAIProviderDescriptor: AIProviderDescriptor<'xai'> = { + id: 'xai', + name: 'xAI', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const githubProviderDescriptor: AIProviderDescriptor<'github'> = { + id: 'github', + name: 'GitHub Models', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; +export const huggingFaceProviderDescriptor: AIProviderDescriptor<'huggingface'> = { + id: 'huggingface', + name: 'Hugging Face', + primary: false, + requiresAccount: true, + requiresUserKey: true, +} as const; diff --git a/src/constants.storage.ts b/src/constants.storage.ts index b609485b0d5fe..8b2909318b546 100644 --- a/src/constants.storage.ts +++ b/src/constants.storage.ts @@ -25,7 +25,7 @@ export const enum SyncedStorageKeys { } export type DeprecatedGlobalStorage = { - /** @deprecated use `confirm:ai:tos:${AIProviders}` */ + /** @deprecated use `confirm:ai:tos` */ 'confirm:sendToOpenAI': boolean; /** @deprecated */ 'home:actions:completed': ('dismissed:welcome' | 'opened:scm')[]; @@ -54,10 +54,14 @@ export type DeprecatedGlobalStorage = { } & { /** @deprecated */ [key in `disallow:connection:${string}`]: any; +} & { + /** @deprecated use `confirm:ai:tos` */ + [key in `confirm:ai:tos:${AIProviders}`]: boolean; }; export type GlobalStorage = { avatars: [string, StoredAvatar][]; + 'confirm:ai:tos': boolean; repoVisibility: [string, StoredRepoVisibilityInfo][]; 'deepLinks:pending': StoredDeepLinkContext; pendingWhatsNewOnFocus: boolean; @@ -82,8 +86,6 @@ export type GlobalStorage = { 'views:scm:grouped:welcome:dismissed': boolean; 'integrations:configured': StoredIntegrationConfigurations; } & { [key in `plus:preview:${FeaturePreviews}:usages`]: StoredFeaturePreviewUsagePeriod[] } & { - [key in `confirm:ai:tos:${AIProviders}`]: boolean; -} & { [key in `provider:authentication:skip:${string}`]: boolean; } & { [key in `gk:${string}:checkin`]: Stored } & { [key in `gk:${string}:organizations`]: Stored; @@ -122,17 +124,21 @@ export interface StoredPromo { } export type DeprecatedWorkspaceStorage = { - /** @deprecated use `confirm:ai:tos:${AIProviders}` */ + /** @deprecated use `confirm:ai:tos` */ 'confirm:sendToOpenAI': boolean; /** @deprecated */ 'graph:banners:dismissed': Record; /** @deprecated */ 'views:searchAndCompare:keepResults': boolean; +} & { + /** @deprecated use `confirm:ai:tos` */ + [key in `confirm:ai:tos:${AIProviders}`]: boolean; }; export type WorkspaceStorage = { assumeRepositoriesOnStartup?: boolean; 'branch:comparisons': StoredBranchComparisons; + 'confirm:ai:tos': boolean; 'gitComandPalette:usage': StoredRecentUsage; gitPath: string; 'graph:columns': Record; @@ -145,7 +151,7 @@ export type WorkspaceStorage = { 'views:repositories:autoRefresh': boolean; 'views:searchAndCompare:pinned': StoredSearchAndCompareItems; 'views:scm:grouped:selected': GroupableTreeViewTypes; -} & { [key in `confirm:ai:tos:${AIProviders}`]: boolean } & { +} & { [key in `connected:${Integration['key']}`]: boolean; }; diff --git a/src/constants.telemetry.ts b/src/constants.telemetry.ts index 73744716e88b1..25032103d63c9 100644 --- a/src/constants.telemetry.ts +++ b/src/constants.telemetry.ts @@ -926,6 +926,8 @@ export type TrackingContext = 'graph' | 'launchpad' | 'visual_file_history' | 'w export type Sources = | 'account' + | 'ai' + | 'ai:picker' | 'associateIssueWithBranch' | 'code-suggest' | 'cloud-patches' diff --git a/src/plus/ai/aiProviderService.ts b/src/plus/ai/aiProviderService.ts index f4b85b4b5fff4..fe0540d5f6789 100644 --- a/src/plus/ai/aiProviderService.ts +++ b/src/plus/ai/aiProviderService.ts @@ -1,7 +1,17 @@ import type { CancellationToken, Disposable, Event, MessageItem, ProgressOptions } from 'vscode'; import { env, EventEmitter, window } from 'vscode'; import type { AIPrimaryProviders, AIProviderAndModel, AIProviders, SupportedAIModels } from '../../constants.ai'; -import { primaryAIProviders } from '../../constants.ai'; +import { + anthropicProviderDescriptor, + deepSeekProviderDescriptor, + geminiProviderDescriptor, + githubProviderDescriptor, + gitKrakenProviderDescriptor, + huggingFaceProviderDescriptor, + openAIProviderDescriptor, + vscodeProviderDescriptor, + xAIProviderDescriptor, +} from '../../constants.ai'; import type { AIGenerateDraftEventData, Source, TelemetryEvents } from '../../constants.telemetry'; import type { Container } from '../../container'; import { CancellationError } from '../../errors'; @@ -13,20 +23,27 @@ import type { GitRevisionReference } from '../../git/models/reference'; import type { Repository } from '../../git/models/repository'; import { uncommitted, uncommittedStaged } from '../../git/models/revision'; import { assertsCommitHasFullDetails } from '../../git/utils/commit.utils'; -import { showAIModelPicker } from '../../quickpicks/aiModelPicker'; +import { showAIModelPicker, showAIProviderPicker } from '../../quickpicks/aiModelPicker'; +import { Directive, isDirective } from '../../quickpicks/items/directive'; import { configuration } from '../../system/-webview/configuration'; import { getContext } from '../../system/-webview/context'; import type { Storage } from '../../system/-webview/storage'; -import { supportedInVSCodeVersion } from '../../system/-webview/vscode'; import { debounce } from '../../system/function/debounce'; import { map } from '../../system/iterable'; import type { Lazy } from '../../system/lazy'; import { lazy } from '../../system/lazy'; import type { Deferred } from '../../system/promise'; -import { getSettledValue } from '../../system/promise'; +import { getSettledValue, getSettledValues } from '../../system/promise'; import type { ServerConnection } from '../gk/serverConnection'; import { ensureFeatureAccess } from '../gk/utils/-webview/acount.utils'; -import type { AIActionType, AIModel, AIModelDescriptor } from './models/model'; +import type { + AIActionType, + AIModel, + AIModelDescriptor, + AIProviderConstructor, + AIProviderDescriptorWithConfiguration, + AIProviderDescriptorWithType, +} from './models/model'; import type { PromptTemplateContext } from './models/promptTemplates'; import type { AIProvider, AIRequestResult } from './models/provider'; @@ -58,53 +75,90 @@ export interface AIGenerateChangelogChange { readonly issues: readonly { readonly id: string; readonly url: string; readonly title: string | undefined }[]; } -interface AIProviderConstructor { - new (container: Container, connection: ServerConnection): AIProvider; +export interface AIModelChangeEvent { + readonly model: AIModel | undefined; } // Order matters for sorting the picker -const _supportedProviderTypes = new Map>>([ +const supportedAIProviders = new Map([ ...(configuration.getAny('gitkraken.ai.enabled', undefined, false) ? [ [ 'gitkraken', - lazy( - async () => - (await import(/* webpackChunkName: "ai" */ './gitkrakenProvider')).GitKrakenProvider, - ), - ], - ] - : []), - ...(supportedInVSCodeVersion('language-models') - ? [ - [ - 'vscode', - lazy(async () => (await import(/* webpackChunkName: "ai" */ './vscodeProvider')).VSCodeAIProvider), + { + ...gitKrakenProviderDescriptor, + type: lazy( + async () => + (await import(/* webpackChunkName: "ai" */ './gitkrakenProvider')).GitKrakenProvider, + ), + }, ], ] : ([] as any)), - ['openai', lazy(async () => (await import(/* webpackChunkName: "ai" */ './openaiProvider')).OpenAIProvider)], + [ + 'vscode', + { + ...vscodeProviderDescriptor, + type: lazy(async () => (await import(/* webpackChunkName: "ai" */ './vscodeProvider')).VSCodeAIProvider), + }, + ], + [ + 'openai', + { + ...openAIProviderDescriptor, + type: lazy(async () => (await import(/* webpackChunkName: "ai" */ './openaiProvider')).OpenAIProvider), + }, + ], [ 'anthropic', - lazy(async () => (await import(/* webpackChunkName: "ai" */ './anthropicProvider')).AnthropicProvider), + { + ...anthropicProviderDescriptor, + type: lazy( + async () => (await import(/* webpackChunkName: "ai" */ './anthropicProvider')).AnthropicProvider, + ), + }, + ], + [ + 'gemini', + { + ...geminiProviderDescriptor, + type: lazy(async () => (await import(/* webpackChunkName: "ai" */ './geminiProvider')).GeminiProvider), + }, + ], + [ + 'deepseek', + { + ...deepSeekProviderDescriptor, + type: lazy(async () => (await import(/* webpackChunkName: "ai" */ './deepSeekProvider')).DeepSeekProvider), + }, + ], + [ + 'xai', + { + ...xAIProviderDescriptor, + type: lazy(async () => (await import(/* webpackChunkName: "ai" */ './xaiProvider')).XAIProvider), + }, ], - ['gemini', lazy(async () => (await import(/* webpackChunkName: "ai" */ './geminiProvider')).GeminiProvider)], - ['deepseek', lazy(async () => (await import(/* webpackChunkName: "ai" */ './deepSeekProvider')).DeepSeekProvider)], - ['xai', lazy(async () => (await import(/* webpackChunkName: "ai" */ './xaiProvider')).XAIProvider)], [ 'github', - lazy(async () => (await import(/* webpackChunkName: "ai" */ './githubModelsProvider')).GitHubModelsProvider), + { + ...githubProviderDescriptor, + type: lazy( + async () => (await import(/* webpackChunkName: "ai" */ './githubModelsProvider')).GitHubModelsProvider, + ), + }, ], [ 'huggingface', - lazy(async () => (await import(/* webpackChunkName: "ai" */ './huggingFaceProvider')).HuggingFaceProvider), + { + ...huggingFaceProviderDescriptor, + type: lazy( + async () => (await import(/* webpackChunkName: "ai" */ './huggingFaceProvider')).HuggingFaceProvider, + ), + }, ], ]); -export interface AIModelChangeEvent { - readonly model: AIModel | undefined; -} - export class AIProviderService implements Disposable { private _model: AIModel | undefined; private _provider: AIProvider | undefined; @@ -126,10 +180,6 @@ export class AIProviderService implements Disposable { this._provider?.dispose(); } - get currentProviderId(): AIProviders | undefined { - return this._provider?.id; - } - private getConfiguredModel(): AIModelDescriptor | undefined { const qualifiedModelId = configuration.get('ai.model') ?? undefined; if (qualifiedModelId == null) return undefined; @@ -168,13 +218,13 @@ export class AIProviderService implements Disposable { }; if (providerId != null && this.supports(providerId)) { - const type = _supportedProviderTypes.get(providerId); + const type = supportedAIProviders.get(providerId)?.type; if (type == null) return []; return loadModels(type); } - const modelResults = await Promise.allSettled(map(_supportedProviderTypes.values(), t => loadModels(t))); + const modelResults = await Promise.allSettled(map(supportedAIProviders.values(), p => loadModels(p.type))); return modelResults.flatMap(m => getSettledValue(m, [])); } @@ -188,10 +238,44 @@ export class AIProviderService implements Disposable { if (options?.silent) return undefined; - const pick = await showAIModelPicker(this.container, cfg); - if (pick == null) return undefined; + let chosenProviderId: AIProviders | undefined; + let chosenModel: AIModel | undefined; + + if (!options?.force) { + const vsCodeModels = await this.getModels('vscode'); + if (vsCodeModels.length !== 0) { + chosenProviderId = 'vscode'; + } else if ((await this.container.subscription.getSubscription()).account?.verified) { + chosenProviderId = 'gitkraken'; + const gitkrakenModels = await this.getModels('gitkraken'); + chosenModel = gitkrakenModels.find(m => m.default); + } + } + + while (true) { + chosenProviderId ??= (await showAIProviderPicker(this.container, cfg))?.provider; + if (chosenProviderId == null) return; + + const provider = supportedAIProviders.get(chosenProviderId); + if (provider == null) return; + + if (!(await this.ensureProviderConfigured(provider, false))) return; + + if (chosenModel == null) { + const result = await showAIModelPicker(this.container, chosenProviderId, cfg); + if (result == null || (isDirective(result) && result !== Directive.Back)) return; + if (result === Directive.Back) { + chosenProviderId = undefined; + continue; + } + + chosenModel = result.model; + } + + break; + } - const model = await this.getOrUpdateModel(pick.model); + const model = await this.getOrUpdateModel(chosenModel); this.container.telemetry.sendEvent( 'ai/switchModel', @@ -205,9 +289,38 @@ export class AIProviderService implements Disposable { source, ); + void (await showConfirmAIProviderToS(this.container.storage)); return model; } + async getProvidersConfiguration(): Promise> { + const promises = await Promise.allSettled( + map( + supportedAIProviders.values(), + async p => + [ + p.id, + { ...p, type: undefined, configured: await this.ensureProviderConfigured(p, true) }, + ] as const, + ), + ); + return new Map(getSettledValues(promises)); + } + + private async ensureProviderConfigured(provider: AIProviderDescriptorWithType, silent: boolean): Promise { + if (provider.id === this._provider?.id) return this._provider.configured(silent); + + const type = await provider.type.value; + if (type == null) return false; + + const p = new type(this.container, this.connection); + try { + return await p.configured(silent); + } finally { + p.dispose(); + } + } + private getOrUpdateModel(model: AIModel): Promise; private getOrUpdateModel(providerId: T, modelId: string): Promise; private async getOrUpdateModel( @@ -230,7 +343,7 @@ export class AIProviderService implements Disposable { this._providerDisposable?.dispose(); this._provider?.dispose(); - const type = await _supportedProviderTypes.get(providerId)?.value; + const type = await supportedAIProviders.get(providerId)?.type.value; if (type == null) { this._provider = undefined; this._model = undefined; @@ -547,12 +660,7 @@ export class AIProviderService implements Disposable { progress?: ProgressOptions; }, ): Promise { - const { confirmed, model } = await getModelAndConfirmAIProviderToS( - 'diff', - source, - this, - this.container.storage, - ); + const model = await this.getModel(undefined, source); if (model == null) { options?.generating?.cancel(); return undefined; @@ -560,6 +668,7 @@ export class AIProviderService implements Disposable { const telementry = getTelemetryInfo(model); + const confirmed = await showConfirmAIProviderToS(this.container.storage); if (!confirmed) { this.container.telemetry.sendEvent( telementry.key, @@ -676,29 +785,46 @@ export class AIProviderService implements Disposable { } if (provider != null && result === resetCurrent) { - void env.clipboard.writeText((await this.container.storage.getSecret(`gitlens.${provider.id}.key`)) ?? ''); - void this.container.storage.deleteSecret(`gitlens.${provider.id}.key`); - - void this.container.storage.delete(`confirm:ai:tos:${provider.id}`); - void this.container.storage.deleteWorkspace(`confirm:ai:tos:${provider.id}`); + this.resetProviderKey(provider.id); + this.resetConfirmations(); } else if (result === resetAll) { const keys = []; - for (const [providerId] of _supportedProviderTypes) { + for (const providerId of supportedAIProviders.keys()) { keys.push(await this.container.storage.getSecret(`gitlens.${providerId}.key`)); + + this.resetProviderKey(providerId, true); } + + this.resetConfirmations(); + void env.clipboard.writeText(keys.join('\n')); + void window.showInformationMessage( + `All stored AI keys have been reset. The configured keys were copied to your clipboard.`, + ); + } + } - for (const [providerId] of _supportedProviderTypes) { - void this.container.storage.deleteSecret(`gitlens.${providerId}.key`); - } + resetConfirmations(): void { + void this.container.storage.deleteWithPrefix(`confirm:ai:tos`); + void this.container.storage.deleteWorkspaceWithPrefix(`confirm:ai:tos`); + } - void this.container.storage.deleteWithPrefix(`confirm:ai:tos`); - void this.container.storage.deleteWorkspaceWithPrefix(`confirm:ai:tos`); + resetProviderKey(provider: AIProviders, silent?: boolean): void { + if (!silent) { + void this.container.storage.getSecret(`gitlens.${provider}.key`).then(key => { + if (key) { + void env.clipboard.writeText(key); + void window.showInformationMessage( + `The stored AI key has been reset. The configured key was copied to your clipboard.`, + ); + } + }); } + void this.container.storage.deleteSecret(`gitlens.${provider}.key`); } supports(provider: AIProviders | string): boolean { - return _supportedProviderTypes.has(provider as AIProviders); + return supportedAIProviders.has(provider as AIProviders); } switchModel(source?: Source): Promise { @@ -706,60 +832,33 @@ export class AIProviderService implements Disposable { } } -async function getModelAndConfirmAIProviderToS( - confirmationType: 'data' | 'diff', - source: Source, - service: AIProviderService, - storage: Storage, -): Promise<{ confirmed: boolean; model: AIModel | undefined }> { - let model = await service.getModel(undefined, source); - while (true) { - if (model == null) return { confirmed: false, model: model }; - - const confirmed = - storage.get(`confirm:ai:tos:${model.provider.id}`, false) || - storage.getWorkspace(`confirm:ai:tos:${model.provider.id}`, false); - if (confirmed) return { confirmed: true, model: model }; - - const accept: MessageItem = { title: 'Continue' }; - const switchModel: MessageItem = { title: 'Switch Model' }; - const acceptWorkspace: MessageItem = { title: 'Always for this Workspace' }; - const acceptAlways: MessageItem = { title: 'Always' }; - const decline: MessageItem = { title: 'Cancel', isCloseAffordance: true }; - - const result = await window.showInformationMessage( - `GitLens AI features require sending ${ - confirmationType === 'data' ? 'data' : 'a diff of the code changes' - } to ${ - model.provider.name - } for analysis. This may contain sensitive information.\n\nDo you want to continue?`, - { modal: true }, - accept, - switchModel, - acceptWorkspace, - acceptAlways, - decline, - ); +async function showConfirmAIProviderToS(storage: Storage): Promise { + const confirmed = storage.get(`confirm:ai:tos`, false) || storage.getWorkspace(`confirm:ai:tos`, false); + if (confirmed) return true; - if (result === switchModel) { - model = await service.switchModel(source); - continue; - } + const acceptAlways: MessageItem = { title: 'Accept' }; + const acceptWorkspace: MessageItem = { title: 'Accept Only for this Workspace' }; + const cancel: MessageItem = { title: 'Cancel', isCloseAffordance: true }; - if (result === accept) return { confirmed: true, model: model }; + const result = await window.showInformationMessage( + 'GitLens AI features can send code snippets, diffs, and other context to your selected AI provider for analysis. This may contain sensitive information.', + { modal: true }, + acceptAlways, + acceptWorkspace, + cancel, + ); - if (result === acceptWorkspace) { - void storage.storeWorkspace(`confirm:ai:tos:${model.provider.id}`, true).catch(); - return { confirmed: true, model: model }; - } - - if (result === acceptAlways) { - void storage.store(`confirm:ai:tos:${model.provider.id}`, true).catch(); - return { confirmed: true, model: model }; - } + if (result === acceptWorkspace) { + void storage.storeWorkspace(`confirm:ai:tos`, true).catch(); + return true; + } - return { confirmed: false, model: model }; + if (result === acceptAlways) { + void storage.store(`confirm:ai:tos`, true).catch(); + return true; } + + return false; } function parseSummarizeResult(result: string): NonNullable { @@ -800,7 +899,7 @@ function splitMessageIntoSummaryAndBody(message: string): NonNullable { diff --git a/src/plus/ai/anthropicProvider.ts b/src/plus/ai/anthropicProvider.ts index ad9fe4c60a27e..3527ba7002f2d 100644 --- a/src/plus/ai/anthropicProvider.ts +++ b/src/plus/ai/anthropicProvider.ts @@ -1,10 +1,9 @@ import type { CancellationToken } from 'vscode'; import type { Response } from '@env/fetch'; +import { anthropicProviderDescriptor as provider } from '../../constants.ai'; import type { AIActionType, AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'anthropic', name: 'Anthropic' } as const; - type AnthropicModel = AIModel; const models: AnthropicModel[] = [ { @@ -106,6 +105,7 @@ const models: AnthropicModel[] = [ export class AnthropicProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://console.anthropic.com/account/keys', keyValidator: /(?:sk-)?[a-zA-Z0-9-_]{32,}/, diff --git a/src/plus/ai/deepSeekProvider.ts b/src/plus/ai/deepSeekProvider.ts index e2d0a36d2419f..5ae5ef035e081 100644 --- a/src/plus/ai/deepSeekProvider.ts +++ b/src/plus/ai/deepSeekProvider.ts @@ -1,8 +1,7 @@ +import { deepSeekProviderDescriptor as provider } from '../../constants.ai'; import type { AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'deepseek', name: 'DeepSeek' } as const; - type DeepSeekModel = AIModel; const models: DeepSeekModel[] = [ { @@ -25,6 +24,7 @@ const models: DeepSeekModel[] = [ export class DeepSeekProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://platform.deepseek.com/api_keys', keyValidator: /(?:sk-)?[a-zA-Z0-9]{32,}/, diff --git a/src/plus/ai/geminiProvider.ts b/src/plus/ai/geminiProvider.ts index 578c3ac94f781..ad85ef104804a 100644 --- a/src/plus/ai/geminiProvider.ts +++ b/src/plus/ai/geminiProvider.ts @@ -1,10 +1,9 @@ import type { CancellationToken } from 'vscode'; import type { Response } from '@env/fetch'; +import { geminiProviderDescriptor as provider } from '../../constants.ai'; import type { AIActionType, AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'gemini', name: 'Google' } as const; - type GeminiModel = AIModel; const models: GeminiModel[] = [ { @@ -102,6 +101,7 @@ const models: GeminiModel[] = [ export class GeminiProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://aistudio.google.com/app/apikey', }; diff --git a/src/plus/ai/githubModelsProvider.ts b/src/plus/ai/githubModelsProvider.ts index 0b3a206ca2342..c75169fee5f3b 100644 --- a/src/plus/ai/githubModelsProvider.ts +++ b/src/plus/ai/githubModelsProvider.ts @@ -1,16 +1,16 @@ import type { Response } from '@env/fetch'; import { fetch } from '@env/fetch'; +import { githubProviderDescriptor as provider } from '../../constants.ai'; import type { AIActionType, AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; import { getMaxCharacters } from './utils/-webview/ai.utils'; -const provider = { id: 'github', name: 'GitHub Models' } as const; - type GitHubModelsModel = AIModel; export class GitHubModelsProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://github.com/settings/tokens', keyValidator: /(?:ghp-)?[a-zA-Z0-9]{32,}/, diff --git a/src/plus/ai/gitkrakenProvider.ts b/src/plus/ai/gitkrakenProvider.ts index 31158588bae40..d4d0b5a1a85a6 100644 --- a/src/plus/ai/gitkrakenProvider.ts +++ b/src/plus/ai/gitkrakenProvider.ts @@ -1,5 +1,6 @@ import type { Disposable } from 'vscode'; import { fetch } from '@env/fetch'; +import { gitKrakenProviderDescriptor as provider } from '../../constants.ai'; import type { Container } from '../../container'; import { AuthenticationRequiredError } from '../../errors'; import { debug } from '../../system/decorators/log'; @@ -10,15 +11,14 @@ import type { ServerConnection } from '../gk/serverConnection'; import type { AIActionType, AIModel } from './models/model'; import type { PromptTemplate } from './models/promptTemplates'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -import { getActionName } from './utils/-webview/ai.utils'; - -const provider = { id: 'gitkraken', name: 'GitKraken AI (Preview)' } as const; +import { ensureAccount, getActionName } from './utils/-webview/ai.utils'; type GitKrakenModel = AIModel; export class GitKrakenProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = {}; private readonly _disposable: Disposable; @@ -141,8 +141,16 @@ export class GitKrakenProvider extends OpenAICompatibleProvider { - return Promise.resolve(''); + protected override async getApiKey(silent: boolean): Promise { + let session = await this.container.subscription.getAuthenticationSession(); + if (session?.accessToken) return session.accessToken; + if (silent) return undefined; + + const result = await ensureAccount(this.container, silent); + if (!result) return undefined; + + session = await this.container.subscription.getAuthenticationSession(); + return session?.accessToken; } protected getUrl(_model: AIModel): string { @@ -153,9 +161,9 @@ export class GitKrakenProvider extends OpenAICompatibleProvider, _url: string, - _apiKey: string, + apiKey: string, ): Promise> { - return this.connection.getGkHeaders(undefined, undefined, { + return this.connection.getGkHeaders(apiKey, undefined, { Accept: 'application/json', 'GK-Action': action, }); diff --git a/src/plus/ai/huggingFaceProvider.ts b/src/plus/ai/huggingFaceProvider.ts index d9ec7df6b2f87..7e73344bd7632 100644 --- a/src/plus/ai/huggingFaceProvider.ts +++ b/src/plus/ai/huggingFaceProvider.ts @@ -1,14 +1,14 @@ import { fetch } from '@env/fetch'; +import { huggingFaceProviderDescriptor as provider } from '../../constants.ai'; import type { AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'huggingface', name: 'Hugging Face' } as const; - type HuggingFaceModel = AIModel; export class HuggingFaceProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://huggingface.co/settings/tokens', keyValidator: /(?:hf_)?[a-zA-Z0-9]{32,}/, diff --git a/src/plus/ai/models/model.ts b/src/plus/ai/models/model.ts index 423f09b391593..aa7bd26d7aa9e 100644 --- a/src/plus/ai/models/model.ts +++ b/src/plus/ai/models/model.ts @@ -1,4 +1,8 @@ -import type { AIProviders } from '../../../constants.ai'; +import type { AIPrimaryProviders, AIProviders } from '../../../constants.ai'; +import type { Container } from '../../../container'; +import type { Lazy } from '../../../system/lazy'; +import type { ServerConnection } from '../../gk/serverConnection'; +import type { AIProvider } from './provider'; export interface AIModel { readonly id: Model; @@ -26,3 +30,27 @@ export type AIActionType = | 'generate-changelog' | `generate-create-${'cloudPatch' | 'codeSuggestion' | 'pullRequest'}` | `explain-changes`; + +export interface AIProviderConstructor { + new (container: Container, connection: ServerConnection): AIProvider; +} + +export interface AIProviderDescriptor { + readonly id: T; + readonly name: string; + readonly primary: T extends AIPrimaryProviders ? true : false; + readonly requiresAccount: boolean; + readonly requiresUserKey: boolean; + + readonly type?: never; +} + +export interface AIProviderDescriptorWithConfiguration + extends AIProviderDescriptor { + readonly configured: boolean; +} + +export interface AIProviderDescriptorWithType + extends Omit, 'type'> { + readonly type: Lazy>>; +} diff --git a/src/plus/ai/models/provider.ts b/src/plus/ai/models/provider.ts index b5af5fb11cb68..882ecf3bf8fc6 100644 --- a/src/plus/ai/models/provider.ts +++ b/src/plus/ai/models/provider.ts @@ -27,6 +27,7 @@ export interface AIProvider extends onDidChange?: Event; + configured(silent: boolean): Promise; getModels(): Promise[]>; getPromptTemplate(action: AIActionType, model: AIModel): Promise; diff --git a/src/plus/ai/openAICompatibleProvider.ts b/src/plus/ai/openAICompatibleProvider.ts index 60d900fb0caa4..689f424955c2c 100644 --- a/src/plus/ai/openAICompatibleProvider.ts +++ b/src/plus/ai/openAICompatibleProvider.ts @@ -9,7 +9,7 @@ import { sum } from '../../system/iterable'; import { getLoggableName, Logger } from '../../system/logger'; import { startLogScope } from '../../system/logger.scope'; import type { ServerConnection } from '../gk/serverConnection'; -import type { AIActionType, AIModel } from './models/model'; +import type { AIActionType, AIModel, AIProviderDescriptor } from './models/model'; import type { PromptTemplate, PromptTemplateContext } from './models/promptTemplates'; import type { AIProvider, AIRequestResult } from './models/provider'; import { @@ -36,8 +36,13 @@ export abstract class OpenAICompatibleProvider implements abstract readonly id: T; abstract readonly name: string; + protected abstract readonly descriptor: AIProviderDescriptor; protected abstract readonly config: { keyUrl?: string; keyValidator?: RegExp }; + async configured(silent: boolean): Promise { + return (await this.getApiKey(silent)) != null; + } + abstract getModels(): Promise[]>; async getPromptTemplate( action: TAction, @@ -48,15 +53,20 @@ export abstract class OpenAICompatibleProvider implements protected abstract getUrl(_model: AIModel): string; - protected async getApiKey(): Promise { + protected async getApiKey(silent: boolean): Promise { const { keyUrl, keyValidator } = this.config; - return getOrPromptApiKey(this.container.storage, { - id: this.id, - name: this.name, - validator: keyValidator != null ? v => keyValidator.test(v) : () => true, - url: keyUrl, - }); + return getOrPromptApiKey( + this.container, + { + id: this.id, + name: this.name, + requiresAccount: this.descriptor.requiresAccount, + validator: keyValidator != null ? v => keyValidator.test(v) : () => true, + url: keyUrl, + }, + silent, + ); } protected getHeaders( @@ -81,7 +91,7 @@ export abstract class OpenAICompatibleProvider implements ): Promise { using scope = startLogScope(`${getLoggableName(this)}.sendRequest`, false); - const apiKey = await this.getApiKey(); + const apiKey = await this.getApiKey(false); if (apiKey == null) return undefined; const prompt = await this.getPromptTemplate(action, model); diff --git a/src/plus/ai/openaiProvider.ts b/src/plus/ai/openaiProvider.ts index 7a49e7d461c1f..0595b54f0f9e9 100644 --- a/src/plus/ai/openaiProvider.ts +++ b/src/plus/ai/openaiProvider.ts @@ -1,9 +1,8 @@ +import { openAIProviderDescriptor as provider } from '../../constants.ai'; import { configuration } from '../../system/-webview/configuration'; import type { AIActionType, AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'openai', name: 'OpenAI' } as const; - type OpenAIModel = AIModel; const models: OpenAIModel[] = [ { @@ -211,6 +210,7 @@ const models: OpenAIModel[] = [ export class OpenAIProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://platform.openai.com/account/api-keys', keyValidator: /(?:sk-)?[a-zA-Z0-9]{32,}/, diff --git a/src/plus/ai/utils/-webview/ai.utils.ts b/src/plus/ai/utils/-webview/ai.utils.ts index 2b4f721495514..ea0eaf6153134 100644 --- a/src/plus/ai/utils/-webview/ai.utils.ts +++ b/src/plus/ai/utils/-webview/ai.utils.ts @@ -1,12 +1,26 @@ import type { Disposable, QuickInputButton } from 'vscode'; import { env, ThemeIcon, Uri, window } from 'vscode'; import type { AIProviders } from '../../../../constants.ai'; +import type { Container } from '../../../../container'; +import { createDirectiveQuickPickItem, Directive } from '../../../../quickpicks/items/directive'; import { configuration } from '../../../../system/-webview/configuration'; -import type { Storage } from '../../../../system/-webview/storage'; import { formatNumeric } from '../../../../system/date'; import { getPossessiveForm } from '../../../../system/string'; +import { ensureAccountQuickPick } from '../../../gk/utils/-webview/acount.utils'; import type { AIActionType, AIModel } from '../../models/model'; +export function ensureAccount(container: Container, silent: boolean): Promise { + return ensureAccountQuickPick( + container, + createDirectiveQuickPickItem(Directive.Noop, undefined, { + label: 'Use AI-powered GitLens features like Generate Commit Message, Explain Commit, and more', + iconPath: new ThemeIcon('sparkle'), + }), + { source: 'ai' }, + silent, + ); +} + export function getActionName(action: AIActionType): string { switch (action) { case 'generate-commitMessage': @@ -33,68 +47,81 @@ export function getMaxCharacters(model: AIModel, outputLength: number, overrideI } export async function getOrPromptApiKey( - storage: Storage, - provider: { id: AIProviders; name: string; validator: (value: string) => boolean; url?: string }, + container: Container, + provider: { + readonly id: AIProviders; + readonly name: string; + readonly requiresAccount: boolean; + readonly validator: (value: string) => boolean; + readonly url?: string; + }, + silent?: boolean, ): Promise { - let apiKey = await storage.getSecret(`gitlens.${provider.id}.key`); - if (!apiKey) { - const input = window.createInputBox(); - input.ignoreFocusOut = true; - - const disposables: Disposable[] = []; - - try { - const infoButton: QuickInputButton = { - iconPath: new ThemeIcon(`link-external`), - tooltip: `Open the ${provider.name} API Key Page`, - }; - - apiKey = await new Promise(resolve => { - disposables.push( - input.onDidHide(() => resolve(undefined)), - input.onDidChangeValue(value => { - if (value && !provider.validator(value)) { - input.validationMessage = `Please enter a valid ${provider.name} API key`; - return; - } - input.validationMessage = undefined; - }), - input.onDidAccept(() => { - const value = input.value.trim(); - if (!value || !provider.validator(value)) { - input.validationMessage = `Please enter a valid ${provider.name} API key`; - return; - } - - resolve(value); - }), - input.onDidTriggerButton(e => { - if (e === infoButton && provider.url) { - void env.openExternal(Uri.parse(provider.url)); - } - }), - ); - - input.password = true; - input.title = `Connect to ${provider.name}`; - input.placeholder = `Please enter your ${provider.name} API key to use this feature`; - input.prompt = `Enter your [${provider.name} API Key](${provider.url} "Get your ${provider.name} API key")`; - if (provider.url) { - input.buttons = [infoButton]; - } - - input.show(); - }); - } finally { - input.dispose(); - disposables.forEach(d => void d.dispose()); - } - - if (!apiKey) return undefined; - - void storage.storeSecret(`gitlens.${provider.id}.key`, apiKey).catch(); + let apiKey = await container.storage.getSecret(`gitlens.${provider.id}.key`); + if (apiKey) return apiKey; + if (silent) return undefined; + + if (provider.requiresAccount) { + const result = await ensureAccount(container, false); + if (!result) return undefined; + } + + const input = window.createInputBox(); + input.ignoreFocusOut = true; + + const disposables: Disposable[] = []; + + try { + const infoButton: QuickInputButton = { + iconPath: new ThemeIcon(`link-external`), + tooltip: `Open the ${provider.name} API Key Page`, + }; + + apiKey = await new Promise(resolve => { + disposables.push( + input.onDidHide(() => resolve(undefined)), + input.onDidChangeValue(value => { + if (value && !provider.validator(value)) { + input.validationMessage = `Please enter a valid ${provider.name} API key`; + return; + } + input.validationMessage = undefined; + }), + input.onDidAccept(() => { + const value = input.value.trim(); + if (!value || !provider.validator(value)) { + input.validationMessage = `Please enter a valid ${provider.name} API key`; + return; + } + + resolve(value); + }), + input.onDidTriggerButton(e => { + if (e === infoButton && provider.url) { + void env.openExternal(Uri.parse(provider.url)); + } + }), + ); + + input.password = true; + input.title = `Connect to ${provider.name}`; + input.placeholder = `Please enter your ${provider.name} API key to use this feature`; + input.prompt = `Enter your [${provider.name} API Key](${provider.url} "Get your ${provider.name} API key")`; + if (provider.url) { + input.buttons = [infoButton]; + } + + input.show(); + }); + } finally { + input.dispose(); + disposables.forEach(d => void d.dispose()); } + if (!apiKey) return undefined; + + void container.storage.storeSecret(`gitlens.${provider.id}.key`, apiKey).catch(); + return apiKey; } diff --git a/src/plus/ai/vscodeProvider.ts b/src/plus/ai/vscodeProvider.ts index eaf634f46aef8..bf69a2df17ef9 100644 --- a/src/plus/ai/vscodeProvider.ts +++ b/src/plus/ai/vscodeProvider.ts @@ -1,5 +1,6 @@ import type { CancellationToken, Event, LanguageModelChat, LanguageModelChatSelector } from 'vscode'; import { CancellationTokenSource, Disposable, EventEmitter, LanguageModelChatMessage, lm } from 'vscode'; +import { vscodeProviderDescriptor } from '../../constants.ai'; import type { TelemetryEvents } from '../../constants.telemetry'; import type { Container } from '../../container'; import { sum } from '../../system/iterable'; @@ -13,7 +14,7 @@ import type { AIProvider, AIRequestResult } from './models/provider'; import { getMaxCharacters, getValidatedTemperature, showDiffTruncationWarning } from './utils/-webview/ai.utils'; import { getLocalPromptTemplate, resolvePrompt } from './utils/-webview/prompt.utils'; -const provider = { id: 'vscode', name: 'VS Code Provided' } as const; +const provider = vscodeProviderDescriptor; type VSCodeAIModel = AIModel & { vendor: string; selector: LanguageModelChatSelector }; @@ -49,6 +50,10 @@ export class VSCodeAIProvider implements AIProvider { this._disposable.dispose(); } + async configured(_silent: boolean): Promise { + return (await this.getModels()).length !== 0; + } + async getModels(): Promise[]> { const models = await lm.selectChatModels(); return models.map(getModelFromChatModel); @@ -158,7 +163,7 @@ export class VSCodeAIProvider implements AIProvider { function getModelFromChatModel(model: LanguageModelChat): VSCodeAIModel { return { id: `${model.vendor}:${model.family}`, - name: `${capitalize(model.vendor)} ${model.name}`, + name: model.vendor === 'copilot' ? model.name : `${capitalize(model.vendor)} ${model.name}`, vendor: model.vendor, selector: { vendor: model.vendor, diff --git a/src/plus/ai/xaiProvider.ts b/src/plus/ai/xaiProvider.ts index 0652acc229ae4..a71fe9e012292 100644 --- a/src/plus/ai/xaiProvider.ts +++ b/src/plus/ai/xaiProvider.ts @@ -1,8 +1,7 @@ +import { xAIProviderDescriptor as provider } from '../../constants.ai'; import type { AIModel } from './models/model'; import { OpenAICompatibleProvider } from './openAICompatibleProvider'; -const provider = { id: 'xai', name: 'xAI' } as const; - type XAIModel = AIModel; const models: XAIModel[] = [ { @@ -17,6 +16,7 @@ const models: XAIModel[] = [ export class XAIProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; + protected readonly descriptor = provider; protected readonly config = { keyUrl: 'https://console.x.ai/', keyValidator: /(?:xai-)?[a-zA-Z0-9]{32,}/, diff --git a/src/plus/gk/utils/-webview/acount.utils.ts b/src/plus/gk/utils/-webview/acount.utils.ts index 59432cc73353a..c986186c776a3 100644 --- a/src/plus/gk/utils/-webview/acount.utils.ts +++ b/src/plus/gk/utils/-webview/acount.utils.ts @@ -3,6 +3,9 @@ import { window } from 'vscode'; import type { Source } from '../../../../constants.telemetry'; import type { Container } from '../../../../container'; import type { PlusFeatures } from '../../../../features'; +import { createQuickPickSeparator } from '../../../../quickpicks/items/common'; +import type { DirectiveQuickPickItem } from '../../../../quickpicks/items/directive'; +import { createDirectiveQuickPickItem, Directive } from '../../../../quickpicks/items/directive'; export async function ensureAccount(container: Container, title: string, source: Source): Promise { while (true) { @@ -55,6 +58,67 @@ export async function ensureAccount(container: Container, title: string, source: return true; } +export async function ensureAccountQuickPick( + container: Container, + descriptionItem: DirectiveQuickPickItem, + source: Source, + silent?: boolean, +): Promise { + while (true) { + const account = (await container.subscription.getSubscription()).account; + if (account?.verified === true) break; + + if (silent) return false; + + const directives: DirectiveQuickPickItem[] = [descriptionItem]; + + let placeholder = 'Requires an account to continue'; + if (account?.verified === false) { + directives.push( + createDirectiveQuickPickItem(Directive.RequiresVerification, true), + createQuickPickSeparator(), + createDirectiveQuickPickItem(Directive.Cancel), + ); + placeholder = 'You must verify your email before you can continue'; + } else { + directives.push( + createDirectiveQuickPickItem(Directive.StartProTrial, true), + createDirectiveQuickPickItem(Directive.SignIn), + createQuickPickSeparator(), + createDirectiveQuickPickItem(Directive.Cancel), + ); + } + + const result = await window.showQuickPick(directives, { + placeHolder: placeholder, + ignoreFocusOut: true, + }); + + if (result == null) return false; + if (result.directive === Directive.Noop) continue; + + if (result.directive === Directive.RequiresVerification) { + if (await container.subscription.resendVerification(source)) { + continue; + } + } + if (result.directive === Directive.StartProTrial) { + if (await container.subscription.loginOrSignUp(true, source)) { + continue; + } + } + if (result.directive === Directive.SignIn) { + if (await container.subscription.loginOrSignUp(false, source)) { + continue; + } + } + + return false; + } + + return true; +} + export async function ensureFeatureAccess( container: Container, title: string, diff --git a/src/quickpicks/aiModelPicker.ts b/src/quickpicks/aiModelPicker.ts index 741e6b505b65f..e0d95bcc8bb9d 100644 --- a/src/quickpicks/aiModelPicker.ts +++ b/src/quickpicks/aiModelPicker.ts @@ -1,38 +1,146 @@ import type { Disposable, QuickInputButton, QuickPickItem } from 'vscode'; -import { QuickPickItemKind, ThemeIcon, window } from 'vscode'; +import { QuickInputButtons, ThemeIcon, window } from 'vscode'; import type { AIProviders } from '../constants.ai'; import type { Container } from '../container'; import type { AIModel, AIModelDescriptor } from '../plus/ai/models/model'; -import { executeCommand } from '../system/-webview/command'; +import { isSubscriptionPaidPlan } from '../plus/gk/utils/subscription.utils'; import { getQuickPickIgnoreFocusOut } from '../system/-webview/vscode'; +import { getSettledValue } from '../system/promise'; +import { createQuickPickSeparator } from './items/common'; +import { Directive } from './items/directive'; export interface ModelQuickPickItem extends QuickPickItem { model: AIModel; } +export interface ProviderQuickPickItem extends QuickPickItem { + provider: AIProviders; +} + +const ClearAIKeyButton: QuickInputButton = { + iconPath: new ThemeIcon('trash'), + tooltip: 'Clear AI Key', +}; + +const ConfigureAIKeyButton: QuickInputButton = { + iconPath: new ThemeIcon('key'), + tooltip: 'Configure AI Key...', +}; + +export async function showAIProviderPicker( + container: Container, + current: AIModelDescriptor | undefined, +): Promise { + const [providersResult, modelResult, subscriptionResult] = await Promise.allSettled([ + container.ai.getProvidersConfiguration(), + container.ai.getModel({ silent: true }, { source: 'ai:picker' }), + container.subscription.getSubscription(), + ]); + + const providers = getSettledValue(providersResult) ?? new Map(); + const currentModelName = getSettledValue(modelResult)?.name; + const subscription = getSettledValue(subscriptionResult)!; + const hasPaidPlan = isSubscriptionPaidPlan(subscription.plan.effective.id) && subscription.account?.verified; + + const quickpick = window.createQuickPick(); + quickpick.ignoreFocusOut = getQuickPickIgnoreFocusOut(); + quickpick.title = 'Select AI Provider'; + quickpick.placeholder = 'Choose an AI provider to use'; + + const disposables: Disposable[] = []; + + try { + const pickedProvider = + current?.provider ?? providers.get('vscode')?.configured + ? 'vscode' + : providers.get('gitkraken')?.configured + ? 'gitkraken' + : undefined; + + let addedRequiredKeySeparator = false; + const items: ProviderQuickPickItem[] = []; + + for (const p of providers.values()) { + if (!p.primary && !addedRequiredKeySeparator) { + addedRequiredKeySeparator = true; + items.push(createQuickPickSeparator('Requires API Key')); + } + + items.push({ + label: p.name, + iconPath: p.id === current?.provider ? new ThemeIcon('check') : new ThemeIcon('blank'), + provider: p.id, + picked: p.id === pickedProvider, + detail: + p.id === current?.provider && currentModelName + ? ` ${currentModelName}` + : p.id === 'gitkraken' + ? ' Models provided by GitKraken' + : undefined, + buttons: !p.primary ? (p.configured ? [ClearAIKeyButton] : [ConfigureAIKeyButton]) : undefined, + description: + p.id === 'gitkraken' + ? hasPaidPlan + ? ' included in your plan' + : ' included in GitLens Pro' + : undefined, + } satisfies ProviderQuickPickItem); + } + + while (true) { + const pick = await new Promise(resolve => { + disposables.push( + quickpick.onDidHide(() => resolve(undefined)), + quickpick.onDidAccept(() => { + if (quickpick.activeItems.length !== 0) { + resolve(quickpick.activeItems[0]); + } + }), + quickpick.onDidTriggerItemButton(e => { + if (e.button === ClearAIKeyButton) { + container.ai.resetProviderKey(e.item.provider); + providers.set(e.item.provider, { ...providers.get(e.item.provider)!, configured: false }); + resolve('refresh'); + } else if (e.button === ConfigureAIKeyButton) { + resolve(e.item); + } + }), + ); + + quickpick.items = items; + quickpick.activeItems = items.filter(i => i.picked); + + quickpick.show(); + }); + + if (pick === 'refresh') continue; + + return pick; + } + } finally { + quickpick.dispose(); + disposables.forEach(d => void d.dispose()); + } +} + export async function showAIModelPicker( container: Container, + provider: AIProviders, current?: AIModelDescriptor, -): Promise { - const models = (await container.ai.getModels()) ?? []; +): Promise { + const models = (await container.ai.getModels(provider)) ?? []; const items: ModelQuickPickItem[] = []; - let lastProvider: AIProviders | undefined; for (const m of models) { if (m.hidden) continue; - if (lastProvider !== m.provider.id) { - lastProvider = m.provider.id; - items.push({ label: m.provider.name, kind: QuickPickItemKind.Separator } as unknown as ModelQuickPickItem); - } - const picked = m.provider.id === current?.provider && m.id === current?.model; items.push({ label: m.name, + description: m.default ? ' recommended' : undefined, iconPath: picked ? new ThemeIcon('check') : new ThemeIcon('blank'), - // description: ` ~${formatNumeric(m.maxTokens)} tokens`, model: m, picked: picked, } satisfies ModelQuickPickItem); @@ -43,13 +151,8 @@ export async function showAIModelPicker( const disposables: Disposable[] = []; - const ResetAIKeyButton: QuickInputButton = { - iconPath: new ThemeIcon('clear-all'), - tooltip: 'Reset AI Keys...', - }; - try { - const pick = await new Promise(resolve => { + const pick = await new Promise(resolve => { disposables.push( quickpick.onDidHide(() => resolve(undefined)), quickpick.onDidAccept(() => { @@ -58,8 +161,8 @@ export async function showAIModelPicker( } }), quickpick.onDidTriggerButton(e => { - if (e === ResetAIKeyButton) { - void executeCommand('gitlens.resetAIKey'); + if (e === QuickInputButtons.Back) { + resolve(Directive.Back); } }), ); @@ -68,8 +171,9 @@ export async function showAIModelPicker( quickpick.placeholder = 'Choose an AI model to use'; quickpick.matchOnDescription = true; quickpick.matchOnDetail = true; - quickpick.buttons = [ResetAIKeyButton]; quickpick.items = items; + quickpick.activeItems = items.filter(i => i.picked); + quickpick.buttons = [QuickInputButtons.Back]; quickpick.show(); }); diff --git a/src/system/promise.ts b/src/system/promise.ts index a990ba46a3a8f..48a0ff94a1618 100644 --- a/src/system/promise.ts +++ b/src/system/promise.ts @@ -187,6 +187,12 @@ export function getSettledValue( return promise?.status === 'fulfilled' ? promise.value : defaultValue; } +export function getSettledValues( + promises: readonly PromiseSettledResult[], +): T[] { + return promises.map(getSettledValue).filter((v): v is T => v != null); +} + export function isPromise(obj: PromiseLike | T): obj is Promise { return obj != null && (obj instanceof Promise || typeof (obj as PromiseLike)?.then === 'function'); } diff --git a/src/webviews/apps/commitDetails/components/gl-commit-details.ts b/src/webviews/apps/commitDetails/components/gl-commit-details.ts index 539f20be4f630..420202ec3d0fe 100644 --- a/src/webviews/apps/commitDetails/components/gl-commit-details.ts +++ b/src/webviews/apps/commitDetails/components/gl-commit-details.ts @@ -439,9 +439,12 @@ export class GlCommitDetails extends GlDetailsBase { return html` Explain (AI) - - +
diff --git a/src/webviews/apps/plus/patchDetails/components/gl-draft-details.ts b/src/webviews/apps/plus/patchDetails/components/gl-draft-details.ts index 25fed917bf93f..218d987ad2272 100644 --- a/src/webviews/apps/plus/patchDetails/components/gl-draft-details.ts +++ b/src/webviews/apps/plus/patchDetails/components/gl-draft-details.ts @@ -201,9 +201,12 @@ export class GlDraftDetails extends GlTreeBase { return html` Explain (AI) - - +
diff --git a/src/webviews/apps/plus/shared/components/integrations-chip.ts b/src/webviews/apps/plus/shared/components/integrations-chip.ts index fa4283f3b5ea2..67dc0b2d07ae6 100644 --- a/src/webviews/apps/plus/shared/components/integrations-chip.ts +++ b/src/webviews/apps/plus/shared/components/integrations-chip.ts @@ -368,8 +368,8 @@ export class GLIntegrationsChip extends LitElement { source: 'home', detail: 'integrations', })}" - tooltip="Switch AI Model" - aria-label="Switch AI Model" + tooltip="Switch AI Provider/Model" + aria-label="Switch AI Provider/Model" > ` diff --git a/src/webviews/home/homeWebview.ts b/src/webviews/home/homeWebview.ts index c5061c29aee8d..f21c5025752fc 100644 --- a/src/webviews/home/homeWebview.ts +++ b/src/webviews/home/homeWebview.ts @@ -689,7 +689,7 @@ export class HomeWebviewProvider implements WebviewProvider