diff --git a/CHANGELOG.md b/CHANGELOG.md index 3be593da9c9e5..957d1008b08cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/) and this p ## [Unreleased] +### Added + +- Adds AI model status and model switcher to the _Home_ view ([#4064](https://github.com/gitkraken/vscode-gitlens/issues/4064)) + ### Fixed - Fixes Settings editor breaking when dragging it to a new tab group ([#4061](https://github.com/gitkraken/vscode-gitlens/issues/4061)) diff --git a/docs/telemetry-events.md b/docs/telemetry-events.md index b73908dc85455..29b478f6ad43a 100644 --- a/docs/telemetry-events.md +++ b/docs/telemetry-events.md @@ -110,7 +110,7 @@ ### ai/explain -> Sent when explaining changes from wip, commits, stashes, patches,etc. +> Sent when explaining changes from wip, commits, stashes, patches, etc. ```typescript { @@ -199,6 +199,26 @@ or } ``` +### ai/switchModel + +> Sent when switching ai models + +```typescript +{ + 'model.id': string, + 'model.provider.id': 'anthropic' | 'deepseek' | 'gemini' | 'github' | 'huggingface' | 'openai' | 'vscode' | 'xai', + 'model.provider.name': string +} +``` + +or + +```typescript +{ + 'failed': true +} +``` + ### associateIssueWithBranch/action > Sent when the user chooses to manage integrations diff --git a/src/ai/aiProviderService.ts b/src/ai/aiProviderService.ts index 809095ea0d7d1..44e1755661385 100644 --- a/src/ai/aiProviderService.ts +++ b/src/ai/aiProviderService.ts @@ -1,7 +1,7 @@ -import type { CancellationToken, Disposable, MessageItem, ProgressOptions, QuickInputButton } from 'vscode'; -import { env, ThemeIcon, Uri, window } from 'vscode'; +import type { CancellationToken, Disposable, Event, MessageItem, ProgressOptions, QuickInputButton } from 'vscode'; +import { env, EventEmitter, ThemeIcon, Uri, window } from 'vscode'; import type { AIProviders, SupportedAIModels, VSCodeAIModels } from '../constants.ai'; -import type { AIGenerateDraftEventData, Sources, TelemetryEvents } from '../constants.telemetry'; +import type { AIGenerateDraftEventData, Source, TelemetryEvents } from '../constants.telemetry'; import type { Container } from '../container'; import { CancellationError } from '../errors'; import type { GitCommit } from '../git/models/commit'; @@ -15,37 +15,36 @@ import { configuration } from '../system/-webview/configuration'; import type { Storage } from '../system/-webview/storage'; import { supportedInVSCodeVersion } from '../system/-webview/vscode'; import { formatNumeric } from '../system/date'; +import { debounce } from '../system/function'; +import { map } from '../system/iterable'; import type { Lazy } from '../system/lazy'; +import { lazy } from '../system/lazy'; import type { Deferred } from '../system/promise'; import { getSettledValue } from '../system/promise'; import { getPossessiveForm } from '../system/string'; -import type { TelemetryService } from '../telemetry/telemetry'; -import { AnthropicProvider } from './anthropicProvider'; -import { DeepSeekProvider } from './deepSeekProvider'; -import { GeminiProvider } from './geminiProvider'; -import { GitHubModelsProvider } from './githubModelsProvider'; -import { HuggingFaceProvider } from './huggingFaceProvider'; -import { OpenAIProvider } from './openaiProvider'; -import { isVSCodeAIModel, VSCodeAIProvider } from './vscodeProvider'; -import { xAIProvider } from './xaiProvider'; export interface AIResult { - summary: string; - body: string; + readonly summary: string; + readonly body: string; } export interface AIGenerateChangelogChange { - message: string; - issues: { id: string; url: string; title: string | undefined }[]; + readonly message: string; + readonly issues: readonly { readonly id: string; readonly url: string; readonly title: string | undefined }[]; +} + +export interface AIModelDescriptor { + readonly provider: Provider; + readonly model: Model; } export interface AIModel { readonly id: Model; readonly name: string; - readonly maxTokens: { input: number; output: number }; + readonly maxTokens: { readonly input: number; readonly output: number }; readonly provider: { - id: Provider; - name: string; + readonly id: Provider; + readonly name: string; }; readonly default?: boolean; @@ -59,21 +58,39 @@ interface AIProviderConstructor { } // Order matters for sorting the picker -const _supportedProviderTypes = new Map([ - ...(supportedInVSCodeVersion('language-models') ? [['vscode', VSCodeAIProvider]] : ([] as any)), - ['openai', OpenAIProvider], - ['anthropic', AnthropicProvider], - ['gemini', GeminiProvider], - ['deepseek', DeepSeekProvider], - ['xai', xAIProvider], - ['github', GitHubModelsProvider], - ['huggingface', HuggingFaceProvider], +const _supportedProviderTypes = new Map>>([ + ...(supportedInVSCodeVersion('language-models') + ? [ + [ + 'vscode', + lazy(async () => (await import(/* webpackChunkName: "ai" */ './vscodeProvider')).VSCodeAIProvider), + ], + ] + : ([] as any)), + ['openai', lazy(async () => (await import(/* webpackChunkName: "ai" */ './openaiProvider')).OpenAIProvider)], + [ + 'anthropic', + lazy(async () => (await import(/* webpackChunkName: "ai" */ './anthropicProvider')).AnthropicProvider), + ], + ['gemini', lazy(async () => (await import(/* webpackChunkName: "ai" */ './geminiProvider')).GeminiProvider)], + ['deepseek', lazy(async () => (await import(/* webpackChunkName: "ai" */ './deepSeekProvider')).DeepSeekProvider)], + ['xai', lazy(async () => (await import(/* webpackChunkName: "ai" */ './xaiProvider')).XAIProvider)], + [ + 'github', + lazy(async () => (await import(/* webpackChunkName: "ai" */ './githubModelsProvider')).GitHubModelsProvider), + ], + [ + 'huggingface', + lazy(async () => (await import(/* webpackChunkName: "ai" */ './huggingFaceProvider')).HuggingFaceProvider), + ], ]); export interface AIProvider extends Disposable { readonly id: Provider; readonly name: string; + onDidChange?: Event; + getModels(): Promise[]>; explainChanges( @@ -109,9 +126,19 @@ export interface AIProvider extends ): Promise; } +export interface AIModelChangeEvent { + readonly model: AIModel | undefined; +} + export class AIProviderService implements Disposable { - private _provider: AIProvider | undefined; private _model: AIModel | undefined; + private _provider: AIProvider | undefined; + private _providerDisposable: Disposable | undefined; + + private readonly _onDidChangeModel = new EventEmitter(); + get onDidChangeModel(): Event { + return this._onDidChangeModel.event; + } constructor(private readonly container: Container) {} @@ -123,7 +150,7 @@ export class AIProviderService implements Disposable { return this._provider?.id; } - private getConfiguredModel(): { provider: AIProviders; model: string } | undefined { + private getConfiguredModel(): AIModelDescriptor | undefined { const qualifiedModelId = configuration.get('ai.model') ?? undefined; if (qualifiedModelId != null) { let [providerId, modelId] = qualifiedModelId.split(':') as [AIProviders, string]; @@ -145,12 +172,23 @@ export class AIProviderService implements Disposable { } async getModels(): Promise { - const providers = [..._supportedProviderTypes.values()].map(p => new p(this.container)); - const models = await Promise.allSettled(providers.map(p => p.getModels())); - return models.flatMap(m => getSettledValue(m, [])); + const modelResults = await Promise.allSettled( + map(_supportedProviderTypes.values(), t => + t.value.then(async t => { + const p = new t(this.container); + try { + return await p.getModels(); + } finally { + p.dispose(); + } + }), + ), + ); + + return modelResults.flatMap(m => getSettledValue(m, [])); } - async getModel(options?: { force?: boolean; silent?: boolean }): Promise { + async getModel(options?: { force?: boolean; silent?: boolean }, source?: Source): Promise { const cfg = this.getConfiguredModel(); if (!options?.force && cfg?.provider != null && cfg?.model != null) { const model = await this.getOrUpdateModel(cfg.provider, cfg.model); @@ -162,7 +200,21 @@ export class AIProviderService implements Disposable { const pick = await showAIModelPicker(this.container, cfg); if (pick == null) return undefined; - return this.getOrUpdateModel(pick.model); + const model = await this.getOrUpdateModel(pick.model); + + this.container.telemetry.sendEvent( + 'ai/switchModel', + model != null + ? { + 'model.id': model.id, + 'model.provider.id': model.provider.id, + 'model.provider.name': model.provider.name, + } + : { failed: true }, + source, + ); + + return model; } private getOrUpdateModel(model: AIModel): Promise; @@ -184,9 +236,10 @@ export class AIProviderService implements Disposable { if (providerId !== this._provider?.id) { changed = true; + this._providerDisposable?.dispose(); this._provider?.dispose(); - const type = _supportedProviderTypes.get(providerId); + const type = await _supportedProviderTypes.get(providerId)?.value; if (type == null) { this._provider = undefined; this._model = undefined; @@ -195,6 +248,17 @@ export class AIProviderService implements Disposable { } this._provider = new type(this.container); + this._providerDisposable = this._provider?.onDidChange?.( + debounce(async () => { + if (this._model != null) return; + + const model = await this.getModel({ silent: true }); + if (model == null) return; + + this._onDidChangeModel.fire({ model: this._model }); + }, 250), + this, + ); } if (model == null) { @@ -203,7 +267,8 @@ export class AIProviderService implements Disposable { } else { changed = true; - model = (await this._provider.getModels())?.find(m => m.id === modelId); + const models = await this._provider.getModels(); + model = models?.find(m => m.id === modelId); if (model == null) { this._model = undefined; @@ -214,6 +279,8 @@ export class AIProviderService implements Disposable { changed = true; } + this._model = model; + if (changed) { if (isVSCodeAIModel(model)) { await configuration.updateEffective(`ai.model`, 'vscode'); @@ -224,15 +291,15 @@ export class AIProviderService implements Disposable { `${model.provider.id}:${model.id}` as SupportedAIModels, ); } + this._onDidChangeModel.fire({ model: model }); } - this._model = model; return model; } async generateCommitMessage( changesOrRepo: string | string[] | Repository, - sourceContext: { source: Sources }, + source: Source, options?: { cancellation?: CancellationToken; context?: string; @@ -243,7 +310,12 @@ export class AIProviderService implements Disposable { const changes: string | undefined = await this.getChanges(changesOrRepo); if (changes == null) return undefined; - const { confirmed, model } = await getModelAndConfirmAIProviderToS('diff', this, this.container.storage); + const { confirmed, model } = await getModelAndConfirmAIProviderToS( + 'diff', + source, + this, + this.container.storage, + ); if (model == null) { options?.generating?.cancel(); return undefined; @@ -256,7 +328,6 @@ export class AIProviderService implements Disposable { 'model.provider.name': model.provider.name, 'retry.count': 0, }; - const source: Parameters[2] = { source: sourceContext.source }; if (!confirmed) { this.container.telemetry.sendEvent('ai/generate', { ...payload, 'failed.reason': 'user-declined' }, source); @@ -315,7 +386,7 @@ export class AIProviderService implements Disposable { async generateDraftMessage( changesOrRepo: string | string[] | Repository, - sourceContext: { source: Sources; type: AIGenerateDraftEventData['draftType'] }, + sourceContext: Source & { type: AIGenerateDraftEventData['draftType'] }, options?: { cancellation?: CancellationToken; context?: string; @@ -327,21 +398,27 @@ export class AIProviderService implements Disposable { const changes: string | undefined = await this.getChanges(changesOrRepo); if (changes == null) return undefined; - const { confirmed, model } = await getModelAndConfirmAIProviderToS('diff', this, this.container.storage); + const { confirmed, model } = await getModelAndConfirmAIProviderToS( + 'diff', + sourceContext, + this, + this.container.storage, + ); if (model == null) { options?.generating?.cancel(); return undefined; } + const { type, ...source } = sourceContext; + const payload: TelemetryEvents['ai/generate'] = { type: 'draftMessage', - draftType: sourceContext.type, + draftType: type, 'model.id': model.id, 'model.provider.id': model.provider.id, 'model.provider.name': model.provider.name, 'retry.count': 0, }; - const source: Parameters[2] = { source: sourceContext.source }; if (!confirmed) { this.container.telemetry.sendEvent('ai/generate', { ...payload, 'failed.reason': 'user-declined' }, source); @@ -398,7 +475,7 @@ export class AIProviderService implements Disposable { async generateStashMessage( changesOrRepo: string | string[] | Repository, - sourceContext: { source: Sources }, + source: Source, options?: { cancellation?: CancellationToken; context?: string; @@ -412,7 +489,12 @@ export class AIProviderService implements Disposable { return undefined; } - const { confirmed, model } = await getModelAndConfirmAIProviderToS('diff', this, this.container.storage); + const { confirmed, model } = await getModelAndConfirmAIProviderToS( + 'diff', + source, + this, + this.container.storage, + ); if (model == null) { options?.generating?.cancel(); return undefined; @@ -425,7 +507,6 @@ export class AIProviderService implements Disposable { 'model.provider.name': model.provider.name, 'retry.count': 0, }; - const source: Parameters[2] = { source: sourceContext.source }; if (!confirmed) { this.container.telemetry.sendEvent('ai/generate', { ...payload, 'failed.reason': 'user-declined' }, source); @@ -484,10 +565,15 @@ export class AIProviderService implements Disposable { async generateChangelog( changes: Lazy>, - sourceContext: { source: Sources }, + source: Source, options?: { cancellation?: CancellationToken; progress?: ProgressOptions }, ): Promise { - const { confirmed, model } = await getModelAndConfirmAIProviderToS('data', this, this.container.storage); + const { confirmed, model } = await getModelAndConfirmAIProviderToS( + 'data', + source, + this, + this.container.storage, + ); if (model == null) return undefined; const payload: TelemetryEvents['ai/generate'] = { @@ -497,7 +583,6 @@ export class AIProviderService implements Disposable { 'model.provider.name': model.provider.name, 'retry.count': 0, }; - const source: Parameters[2] = { source: sourceContext.source }; if (!confirmed) { this.container.telemetry.sendEvent('ai/generate', { ...payload, 'failed.reason': 'user-declined' }, source); @@ -575,24 +660,30 @@ export class AIProviderService implements Disposable { async explainCommit( commitOrRevision: GitRevisionReference | GitCommit, - sourceContext: { source: Sources; type: TelemetryEvents['ai/explain']['changeType'] }, + sourceContext: Source & { type: TelemetryEvents['ai/explain']['changeType'] }, options?: { cancellation?: CancellationToken; progress?: ProgressOptions }, ): Promise { const diff = await this.container.git.diff(commitOrRevision.repoPath).getDiff?.(commitOrRevision.ref); if (!diff?.contents) throw new Error('No changes found to explain.'); - const { confirmed, model } = await getModelAndConfirmAIProviderToS('diff', this, this.container.storage); + const { confirmed, model } = await getModelAndConfirmAIProviderToS( + 'diff', + sourceContext, + this, + this.container.storage, + ); if (model == null) return undefined; + const { type, ...source } = sourceContext; + const payload: TelemetryEvents['ai/explain'] = { type: 'change', - changeType: sourceContext.type, + changeType: type, 'model.id': model.id, 'model.provider.id': model.provider.id, 'model.provider.name': model.provider.name, 'retry.count': 0, }; - const source: Parameters[2] = { source: sourceContext.source }; if (!confirmed) { this.container.telemetry.sendEvent('ai/explain', { ...payload, 'failed.reason': 'user-declined' }, source); @@ -706,17 +797,18 @@ export class AIProviderService implements Disposable { return _supportedProviderTypes.has(provider as AIProviders); } - switchModel(): Promise { - return this.getModel({ force: true }); + switchModel(source?: Source): Promise { + return this.getModel({ force: true }, source); } } async function getModelAndConfirmAIProviderToS( confirmationType: 'data' | 'diff', + source: Source, service: AIProviderService, storage: Storage, ): Promise<{ confirmed: boolean; model: AIModel | undefined }> { - let model = await service.getModel(); + let model = await service.getModel(undefined, source); while (true) { if (model == null) return { confirmed: false, model: model }; @@ -746,7 +838,7 @@ async function getModelAndConfirmAIProviderToS( ); if (result === switchModel) { - model = await service.switchModel(); + model = await service.switchModel(source); continue; } @@ -894,3 +986,7 @@ export function getValidatedTemperature(modelTemperature?: number | null): numbe if (modelTemperature != null) return modelTemperature; return Math.max(0, Math.min(configuration.get('ai.modelOptions.temperature'), 2)); } + +function isVSCodeAIModel(model: AIModel): model is AIModel<'vscode', VSCodeAIModels> { + return model.provider.id === 'vscode'; +} diff --git a/src/ai/vscodeProvider.ts b/src/ai/vscodeProvider.ts index fa4e7ab751e18..eff50b5e4ee41 100644 --- a/src/ai/vscodeProvider.ts +++ b/src/ai/vscodeProvider.ts @@ -1,6 +1,5 @@ -import type { CancellationToken, LanguageModelChat, LanguageModelChatSelector } from 'vscode'; -import { CancellationTokenSource, LanguageModelChatMessage, lm } from 'vscode'; -import type { VSCodeAIModels } from '../constants.ai'; +import type { CancellationToken, Disposable, Event, LanguageModelChat, LanguageModelChatSelector } from 'vscode'; +import { CancellationTokenSource, EventEmitter, LanguageModelChatMessage, lm } from 'vscode'; import type { TelemetryEvents } from '../constants.telemetry'; import type { Container } from '../container'; import { configuration } from '../system/-webview/configuration'; @@ -26,10 +25,6 @@ const provider = { id: 'vscode', name: 'VS Code Provided' } as const; type VSCodeAIModel = AIModel & { vendor: string; selector: LanguageModelChatSelector }; -export function isVSCodeAIModel(model: AIModel): model is AIModel { - return model.provider.id === provider.id; -} - const accessJustification = 'GitLens leverages Copilot for AI-powered features to improve your workflow and development experience.'; @@ -41,9 +36,20 @@ export class VSCodeAIProvider implements AIProvider { return this._name ?? provider.name; } - constructor(private readonly container: Container) {} + private _onDidChange = new EventEmitter(); + get onDidChange(): Event { + return this._onDidChange.event; + } + + private readonly _disposable: Disposable; - dispose(): void {} + constructor(private readonly container: Container) { + this._disposable = lm.onDidChangeChatModels(() => this._onDidChange.fire()); + } + + dispose(): void { + this._disposable.dispose(); + } async getModels(): Promise[]> { const models = await lm.selectChatModels(); diff --git a/src/ai/xaiProvider.ts b/src/ai/xaiProvider.ts index 12c2af4e1c7d3..1ab46feeda5f9 100644 --- a/src/ai/xaiProvider.ts +++ b/src/ai/xaiProvider.ts @@ -3,8 +3,8 @@ import { OpenAICompatibleProvider } from './openAICompatibleProvider'; const provider = { id: 'xai', name: 'xAI' } as const; -type xAIModel = AIModel; -const models: xAIModel[] = [ +type XAIModel = AIModel; +const models: XAIModel[] = [ { id: 'grok-beta', name: 'Grok Beta', @@ -14,7 +14,7 @@ const models: xAIModel[] = [ }, ]; -export class xAIProvider extends OpenAICompatibleProvider { +export class XAIProvider extends OpenAICompatibleProvider { readonly id = provider.id; readonly name = provider.name; protected readonly config = { diff --git a/src/commands/generateCommitMessage.ts b/src/commands/generateCommitMessage.ts index a1e7169a27ee7..9717581b45725 100644 --- a/src/commands/generateCommitMessage.ts +++ b/src/commands/generateCommitMessage.ts @@ -55,9 +55,7 @@ export class GenerateCommitMessageCommand extends ActiveEditorCommand { try { const currentMessage = scmRepo.inputBox.value; - const message = await ( - await this.container.ai - )?.generateCommitMessage( + const message = await this.container.ai.generateCommitMessage( repository, { source: args?.source ?? 'commandPalette' }, { diff --git a/src/commands/git/stash.ts b/src/commands/git/stash.ts index 4c29debdc66fc..4beebd858a811 100644 --- a/src/commands/git/stash.ts +++ b/src/commands/git/stash.ts @@ -671,9 +671,11 @@ export class StashGitCommand extends QuickCommand { }, ); - const result = await ( - await this.container.ai - )?.generateStashMessage(diff.contents, { source: 'quick-wizard' }, { generating: generating }); + const result = await this.container.ai.generateStashMessage( + diff.contents, + { source: 'quick-wizard' }, + { generating: generating }, + ); input.validationMessage = undefined; diff --git a/src/commands/resets.ts b/src/commands/resets.ts index d09e37aa51202..bedb95f8edc35 100644 --- a/src/commands/resets.ts +++ b/src/commands/resets.ts @@ -169,7 +169,7 @@ export class ResetCommand extends GlCommandBase { break; case 'ai': - await (await this.container.ai)?.reset(true); + await this.container.ai.reset(true); break; case 'avatars': @@ -218,6 +218,6 @@ export class ResetAIKeyCommand extends GlCommandBase { } async execute(): Promise { - await (await this.container.ai)?.reset(); + await this.container.ai.reset(); } } diff --git a/src/commands/switchAIModel.ts b/src/commands/switchAIModel.ts index eb907197e6bf9..deefbc21cd153 100644 --- a/src/commands/switchAIModel.ts +++ b/src/commands/switchAIModel.ts @@ -1,3 +1,4 @@ +import type { Source } from '../constants.telemetry'; import type { Container } from '../container'; import { command } from '../system/-webview/command'; import { GlCommandBase } from './commandBase'; @@ -8,7 +9,7 @@ export class SwitchAIModelCommand extends GlCommandBase { super('gitlens.switchAIModel'); } - async execute(): Promise { - await (await this.container.ai)?.switchModel(); + async execute(source?: Source): Promise { + await this.container.ai.switchModel(source); } } diff --git a/src/constants.telemetry.ts b/src/constants.telemetry.ts index 127c92bd4d0d9..d4e007b690744 100644 --- a/src/constants.telemetry.ts +++ b/src/constants.telemetry.ts @@ -57,28 +57,31 @@ export interface TelemetryEvents extends WebviewShowAbortedEvents, WebviewShownE /** Sent when GitLens is activated */ activate: ActivateEvent; - /** Sent when explaining changes from wip, commits, stashes, patches,etc. */ + /** Sent when explaining changes from wip, commits, stashes, patches, etc. */ 'ai/explain': AIExplainEvent; /** Sent when generating summaries from commits, stashes, patches, etc. */ 'ai/generate': AIGenerateEvent; - /** Sent when connecting to one or more cloud-based integrations*/ + /** Sent when switching ai models */ + 'ai/switchModel': AISwitchModelEvent; + + /** Sent when connecting to one or more cloud-based integrations */ 'cloudIntegrations/connecting': CloudIntegrationsConnectingEvent; - /** Sent when connected to one or more cloud-based integrations from gkdev*/ + /** Sent when connected to one or more cloud-based integrations from gkdev */ 'cloudIntegrations/connected': CloudIntegrationsConnectedEvent; - /** Sent when disconnecting a provider from the api fails*/ + /** Sent when disconnecting a provider from the api fails */ 'cloudIntegrations/disconnect/failed': CloudIntegrationsDisconnectFailedEvent; - /** Sent when getting connected providers from the api fails*/ + /** Sent when getting connected providers from the api fails */ 'cloudIntegrations/getConnections/failed': CloudIntegrationsGetConnectionsFailedEvent; - /** Sent when getting a provider token from the api fails*/ + /** Sent when getting a provider token from the api fails */ 'cloudIntegrations/getConnection/failed': CloudIntegrationsGetConnectionFailedEvent; - /** Sent when refreshing a provider token from the api fails*/ + /** Sent when refreshing a provider token from the api fails */ 'cloudIntegrations/refreshConnection/failed': CloudIntegrationsRefreshConnectionFailedEvent; /** Sent when a cloud-based hosting provider is connected */ @@ -341,6 +344,14 @@ type AIGenerateEvent = | AIGenerateStashEventData | AIGenerateChangelogEventData; +export type AISwitchModelEvent = + | { + 'model.id': string; + 'model.provider.id': AIProviders; + 'model.provider.name': string; + } + | { failed: true }; + interface CloudIntegrationsConnectingEvent { 'integration.ids': string | undefined; } diff --git a/src/container.ts b/src/container.ts index 3e4ec3b2de334..8102a22e8614e 100644 --- a/src/container.ts +++ b/src/container.ts @@ -6,7 +6,7 @@ import { getSupportedRepositoryLocationProvider, getSupportedWorkspacesStorageProvider, } from '@env/providers'; -import type { AIProviderService } from './ai/aiProviderService'; +import { AIProviderService } from './ai/aiProviderService'; import { FileAnnotationController } from './annotations/fileAnnotationController'; import { LineAnnotationController } from './annotations/lineAnnotationController'; import { ActionRunners } from './api/actionRunners'; @@ -343,23 +343,10 @@ export class Container { return this._actionRunners; } - private _ai: Promise | undefined; - get ai(): Promise { + private _ai: AIProviderService | undefined; + get ai(): AIProviderService { if (this._ai == null) { - async function load(this: Container) { - try { - const ai = new ( - await import(/* webpackChunkName: "ai" */ './ai/aiProviderService') - ).AIProviderService(this); - this._disposables.push(ai); - return ai; - } catch (ex) { - Logger.error(ex); - return undefined; - } - } - - this._ai = load.call(this); + this._disposables.push((this._ai = new AIProviderService(this))); } return this._ai; } diff --git a/src/env/node/git/commitMessageProvider.ts b/src/env/node/git/commitMessageProvider.ts index 270f24bec3df0..2748916a95c56 100644 --- a/src/env/node/git/commitMessageProvider.ts +++ b/src/env/node/git/commitMessageProvider.ts @@ -49,9 +49,7 @@ class AICommitMessageProvider implements CommitMessageProvider, Disposable { const currentMessage = repository.inputBox.value; try { - const message = await ( - await this.container.ai - )?.generateCommitMessage( + const message = await this.container.ai.generateCommitMessage( changes, { source: 'scm-input' }, { diff --git a/src/git/utils/-webview/log.utils.ts b/src/git/utils/-webview/log.utils.ts index 580346fce4835..7726b5b15feb0 100644 --- a/src/git/utils/-webview/log.utils.ts +++ b/src/git/utils/-webview/log.utils.ts @@ -42,7 +42,7 @@ export async function getChangesForChangelog(container: Container, log: GitLog): } for (const change of changes) { - change.issues.push( + (change.issues as Mutable).push( ...map(change.links, ([key, link]) => { const issue = issues.get(key); return { diff --git a/src/quickpicks/aiModelPicker.ts b/src/quickpicks/aiModelPicker.ts index e077d7f867743..dbe19fa17e28f 100644 --- a/src/quickpicks/aiModelPicker.ts +++ b/src/quickpicks/aiModelPicker.ts @@ -1,6 +1,6 @@ import type { Disposable, QuickInputButton, QuickPickItem } from 'vscode'; import { QuickPickItemKind, ThemeIcon, window } from 'vscode'; -import type { AIModel } from '../ai/aiProviderService'; +import type { AIModel, AIModelDescriptor } from '../ai/aiProviderService'; import type { AIProviders } from '../constants.ai'; import type { Container } from '../container'; import { executeCommand } from '../system/-webview/command'; @@ -12,9 +12,9 @@ export interface ModelQuickPickItem extends QuickPickItem { export async function showAIModelPicker( container: Container, - current?: { provider: AIProviders; model: string }, + current?: AIModelDescriptor, ): Promise { - const models = (await (await container.ai)?.getModels()) ?? []; + const models = (await container.ai.getModels()) ?? []; const items: ModelQuickPickItem[] = []; diff --git a/src/views/viewCommands.ts b/src/views/viewCommands.ts index 01d96af2da659..8f8d84ce82280 100644 --- a/src/views/viewCommands.ts +++ b/src/views/viewCommands.ts @@ -1759,9 +1759,11 @@ export class ViewCommands implements Disposable { if (!node.is('results-commits')) return; const changes = lazy(() => node.getChangesForChangelog()); - const changelog = await ( - await this.container.ai - )?.generateChangelog(changes, { source: 'view' }, { progress: { location: ProgressLocation.Notification } }); + const changelog = await this.container.ai.generateChangelog( + changes, + { source: 'view' }, + { progress: { location: ProgressLocation.Notification } }, + ); if (changelog == null) return; // open an untitled editor diff --git a/src/webviews/apps/home/stateProvider.ts b/src/webviews/apps/home/stateProvider.ts index 8fa0a4d11a575..197b2a308b5bf 100644 --- a/src/webviews/apps/home/stateProvider.ts +++ b/src/webviews/apps/home/stateProvider.ts @@ -70,6 +70,7 @@ export class HomeStateProvider implements StateProvider { case DidChangeIntegrationsConnections.is(msg): this._state.hasAnyIntegrationConnected = msg.params.hasAnyIntegrationConnected; this._state.integrations = msg.params.integrations; + this._state.ai = msg.params.ai; this._state.timestamp = Date.now(); this.provider.setValue(this._state, true); diff --git a/src/webviews/apps/plus/shared/components/integrations-chip.ts b/src/webviews/apps/plus/shared/components/integrations-chip.ts index 7e7b3c414c836..1a4d27bcb27e8 100644 --- a/src/webviews/apps/plus/shared/components/integrations-chip.ts +++ b/src/webviews/apps/plus/shared/components/integrations-chip.ts @@ -37,7 +37,7 @@ export class GLIntegrationsChip extends LitElement { chipStyles, css` .chip { - gap: 0.8rem; + gap: 0.6rem; padding: 0.2rem 0.4rem 0.4rem 0.4rem; align-items: baseline; } @@ -93,6 +93,11 @@ export class GLIntegrationsChip extends LitElement { align-items: center; } + .integration-row--ai { + border-top: 1px solid var(--color-foreground--25); + padding-top: 0.6rem; + } + .status--disconnected .integration__icon { color: var(--color-foreground--25); } @@ -165,6 +170,10 @@ export class GLIntegrationsChip extends LitElement { return this.hasAccount && this.integrations.some(i => i.connected); } + private get ai() { + return this._state.ai; + } + private get integrations() { return this._state.integrations; } @@ -175,12 +184,13 @@ export class GLIntegrationsChip extends LitElement { override render(): unknown { const anyConnected = this.hasConnectedIntegrations; - const statusFilter = createIconBasedStatusFilter(this.integrations); + const statusFilter = createStatusIconFilter(this.integrations); + return html` ${!anyConnected ? html`Connect` : ''}${this.integrations .filter(statusFilter) - .map(i => this.renderIntegrationStatus(i, anyConnected))} this.renderIntegrationStatus(i))}${this.renderAIStatus()}
@@ -227,32 +237,25 @@ export class GLIntegrationsChip extends LitElement { > ` : this.integrations.map(i => this.renderIntegrationRow(i)) - }
+ }${this.renderAIRow()}
`; } - private renderIntegrationStatus(integration: IntegrationState, anyConnected: boolean) { + private renderIntegrationStatus(integration: IntegrationState) { if (integration.requiresPro && !this.isProAccount) { return html``; } return html`${anyConnected - ? html`` - : nothing}`; + >`; } private renderIntegrationRow(integration: IntegrationState) { @@ -270,7 +273,7 @@ export class GLIntegrationsChip extends LitElement { ${showProBadge ? html` ` : nothing} @@ -282,7 +285,8 @@ export class GLIntegrationsChip extends LitElement { ? html` `; } + + private renderAIStatus() { + return html` + + `; + } + + private renderAIRow() { + const { model } = this.ai; + + const connected = model != null; + const showLock = false; + const showProBadge = false; + const icon = connected ? 'sparkle-filled' : 'sparkle'; // TODO: Provider? + + return html`
+ + + + ${model?.name ?? 'AI'} + ${showProBadge + ? html` ` + : nothing} + + ${model?.provider + ? html`AI Provider: ${model.provider.name}` + : nothing} + + + + +
`; + } } + const featureMap = new Map([ - ['prs', 'Pull Requests'], - ['issues', 'Issues'], + ['prs', 'pull requests'], + ['issues', 'issues'], ]); + function getIntegrationDetails(integration: IntegrationState): string { const features = integration.supports.map(feature => featureMap.get(feature)!); @@ -323,32 +382,19 @@ function getIntegrationDetails(integration: IntegrationState): string { if (features.length === 1) return `Supports ${features[0]}`; const last = features.pop(); - return `Supports ${features.join(', ')} and ${last}`; + return `Supports ${features.join(', ')}, and ${last}`; } -function createIconBasedStatusFilter(integrations: IntegrationState[]) { - const nothing = -1; - const icons = integrations.reduce<{ - [key: string]: undefined | { connectedIndex: number; firstIndex: number }; - }>((icons, i, index) => { - const state = icons[i.icon]; - if (!state) { - icons[i.icon] = { connectedIndex: i.connected ? index : nothing, firstIndex: index }; - } else if (i.connected && state.connectedIndex === nothing) { - state.connectedIndex = index; - } - return icons; - }, {}); - - // This filter returns true or false to allow or decline the integration. - // If nothing is connected with the same icon then allows the first one. - // If any connected then allows the first connected. - return function filter(i: IntegrationState, index: number) { - const state = icons[i.icon]; - if (state === undefined) return true; - if (state.connectedIndex !== nothing) { - return state.connectedIndex === index; +function createStatusIconFilter(integrations: IntegrationState[]) { + const groupedIconMap = new Map(); + + // Group the integrations by icon, and if one is connected + for (const integration of integrations) { + const existing = groupedIconMap.get(integration.icon); + if (!existing || (integration.connected && !existing.connected)) { + groupedIconMap.set(integration.icon, integration); } - return state.firstIndex === index; - }; + } + + return (integration: IntegrationState) => groupedIconMap.get(integration.icon) === integration; } diff --git a/src/webviews/commitDetails/commitDetailsWebview.ts b/src/webviews/commitDetails/commitDetailsWebview.ts index a48db5dfbdcf1..5a1ed15bc007b 100644 --- a/src/webviews/commitDetails/commitDetailsWebview.ts +++ b/src/webviews/commitDetails/commitDetailsWebview.ts @@ -1126,9 +1126,7 @@ export class CommitDetailsWebviewProvider private async explainRequest(requestType: T, msg: IpcCallMessageType) { let params: DidExplainParams; try { - const result = await ( - await this.container.ai - )?.explainCommit( + const result = await this.container.ai.explainCommit( this._context.commit!, { source: 'inspect', type: isStash(this._context.commit) ? 'stash' : 'commit' }, { progress: { location: { viewId: this.host.id } } }, @@ -1162,9 +1160,7 @@ export class CommitDetailsWebviewProvider // const commit = await this.getOrCreateCommitForPatch(patch.gkRepositoryId); // if (commit == null) throw new Error('Unable to find commit'); - const message = await ( - await this.container.ai - )?.generateDraftMessage( + const message = await this.container.ai.generateDraftMessage( repo, { source: 'inspect', type: 'suggested_pr_change' }, { progress: { location: { viewId: this.host.id } } }, diff --git a/src/webviews/home/homeWebview.ts b/src/webviews/home/homeWebview.ts index dbc26395da7a2..0afc83b39a6d9 100644 --- a/src/webviews/home/homeWebview.ts +++ b/src/webviews/home/homeWebview.ts @@ -1,5 +1,6 @@ import type { ConfigurationChangeEvent } from 'vscode'; import { Disposable, Uri, window, workspace } from 'vscode'; +import type { AIModelChangeEvent } from '../../ai/aiProviderService'; import type { CreatePullRequestActionContext } from '../../api/gitlens'; import type { EnrichedAutolink } from '../../autolinks/models/autolinks'; import { getAvatarUriFromGravatarEmail } from '../../avatars'; @@ -159,10 +160,11 @@ export class HomeWebviewProvider implements WebviewProvider { - const [subResult, integrationResult] = await Promise.allSettled([ + const [subResult, integrationResult, aiModelResult] = await Promise.allSettled([ this.getSubscriptionState(subscription), this.getIntegrationStates(true), + this.container.ai.getModel({ silent: true }), ]); if (subResult.status === 'rejected') { @@ -682,6 +689,7 @@ export class HomeWebviewProvider implements WebviewProvider i.connected); + const ai = { model: getSettledValue(aiModelResult) }; return { ...this.host.baseWebviewState, @@ -695,6 +703,7 @@ export class HomeWebviewProvider implements WebviewProvider i.connected); + const ai = { model: getSettledValue(aiModelResult) }; + if (anyConnected) { this.onCollapseSection({ section: 'integrationBanner', @@ -1149,6 +1165,7 @@ export class HomeWebviewProvider implements WebviewProvider( scope, diff --git a/src/webviews/plus/patchDetails/patchDetailsWebview.ts b/src/webviews/plus/patchDetails/patchDetailsWebview.ts index d74c3e5ca85cb..ea07903ed352a 100644 --- a/src/webviews/plus/patchDetails/patchDetailsWebview.ts +++ b/src/webviews/plus/patchDetails/patchDetailsWebview.ts @@ -825,9 +825,7 @@ export class PatchDetailsWebviewProvider const commit = await this.getOrCreateCommitForPatch(patch.gkRepositoryId); if (commit == null) throw new Error('Unable to find commit'); - const result = await ( - await this.container.ai - )?.explainCommit( + const result = await this.container.ai.explainCommit( commit, { source: 'patchDetails', type: `draft-${this._context.draft.type}` }, { progress: { location: { viewId: this.host.id } } }, @@ -869,9 +867,7 @@ export class PatchDetailsWebviewProvider // const commit = await this.getOrCreateCommitForPatch(patch.gkRepositoryId); // if (commit == null) throw new Error('Unable to find commit'); - const message = await ( - await this.container.ai - )?.generateDraftMessage( + const message = await this.container.ai.generateDraftMessage( repo, { source: 'patchDetails', type: 'patch' }, { progress: { location: { viewId: this.host.id } } },