Skip to content

Commit 4a91d23

Browse files
committed
Improves AI error handling and responses
Adds caching for prompt templates
1 parent 7f3dcbc commit 4a91d23

File tree

11 files changed

+192
-100
lines changed

11 files changed

+192
-100
lines changed

src/commands/generateCommitMessage.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,19 @@ export class GenerateCommitMessageCommand extends ActiveEditorCommand {
6161

6262
try {
6363
const currentMessage = scmRepo.inputBox.value;
64-
const message = await this.container.ai.generateCommitMessage(
64+
const result = await this.container.ai.generateCommitMessage(
6565
repository,
6666
{ source: args?.source ?? 'commandPalette' },
6767
{
6868
context: currentMessage,
6969
progress: { location: ProgressLocation.Notification, title: 'Generating commit message...' },
7070
},
7171
);
72-
if (message == null) return;
72+
if (result == null) return;
7373

7474
void executeCoreCommand('workbench.view.scm');
75-
scmRepo.inputBox.value = `${currentMessage ? `${currentMessage}\n\n` : ''}${message.summary}\n\n${
76-
message.body
75+
scmRepo.inputBox.value = `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}\n\n${
76+
result.parsed.body
7777
}`;
7878
} catch (ex) {
7979
Logger.error(ex, 'GenerateCommitMessageCommand');

src/commands/git/stash.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ export class StashGitCommand extends QuickCommand<State> {
683683

684684
input.validationMessage = undefined;
685685

686-
const message = result?.summary;
686+
const message = result?.parsed.summary;
687687
if (message != null) {
688688
state.message = message;
689689
input.value = message;

src/env/node/git/commitMessageProvider.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class AICommitMessageProvider implements CommitMessageProvider, Disposable {
4949

5050
const currentMessage = repository.inputBox.value;
5151
try {
52-
const message = await this.container.ai.generateCommitMessage(
52+
const result = await this.container.ai.generateCommitMessage(
5353
changes,
5454
{ source: 'scm-input' },
5555
{
@@ -62,8 +62,8 @@ class AICommitMessageProvider implements CommitMessageProvider, Disposable {
6262
},
6363
);
6464

65-
if (message == null) return;
66-
return `${currentMessage ? `${currentMessage}\n\n` : ''}${message.summary}\n\n${message.body}`;
65+
if (result == null) return;
66+
return `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}\n\n${result.parsed.body}`;
6767
} catch (ex) {
6868
Logger.error(ex, scope);
6969

src/plus/ai/aiProviderService.ts

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,29 @@ import { getSettledValue } from '../../system/promise';
2424
import type { ServerConnection } from '../gk/serverConnection';
2525
import type { AIActionType, AIModel, AIModelDescriptor } from './models/model';
2626
import type { PromptTemplateContext } from './models/promptTemplates';
27-
import type { AIProvider } from './models/provider';
27+
import type { AIProvider, AIRequestResult } from './models/provider';
2828

2929
export interface AIResult {
30-
readonly summary: string;
31-
readonly body: string;
30+
readonly id?: string;
31+
readonly content: string;
32+
readonly usage?: {
33+
readonly promptTokens?: number;
34+
readonly completionTokens?: number;
35+
readonly totalTokens?: number;
36+
37+
readonly limits?: {
38+
readonly used: number;
39+
readonly limit: number;
40+
readonly resetsOn: Date;
41+
};
42+
};
43+
}
44+
45+
export interface AISummarizeResult extends AIResult {
46+
readonly parsed: {
47+
readonly summary: string;
48+
readonly body: string;
49+
};
3250
}
3351

3452
export interface AIGenerateChangelogChange {
@@ -270,7 +288,7 @@ export class AIProviderService implements Disposable {
270288
commitOrRevision: GitRevisionReference | GitCommit,
271289
sourceContext: Source & { type: TelemetryEvents['ai/explain']['changeType'] },
272290
options?: { cancellation?: CancellationToken; progress?: ProgressOptions },
273-
): Promise<AIResult | undefined> {
291+
): Promise<AISummarizeResult | undefined> {
274292
const diff = await this.container.git.diff(commitOrRevision.repoPath).getDiff?.(commitOrRevision.ref);
275293
if (!diff?.contents) throw new Error('No changes found to explain.');
276294

@@ -308,7 +326,7 @@ export class AIProviderService implements Disposable {
308326
}),
309327
options,
310328
);
311-
return result != null ? parseResult(result) : undefined;
329+
return result != null ? { ...result, parsed: parseSummarizeResult(result.content) } : undefined;
312330
}
313331

314332
async generateCommitMessage(
@@ -320,7 +338,7 @@ export class AIProviderService implements Disposable {
320338
generating?: Deferred<AIModel>;
321339
progress?: ProgressOptions;
322340
},
323-
): Promise<AIResult | undefined> {
341+
): Promise<AISummarizeResult | undefined> {
324342
const changes: string | undefined = await this.getChanges(changesOrRepo);
325343
if (changes == null) return undefined;
326344

@@ -345,7 +363,7 @@ export class AIProviderService implements Disposable {
345363
}),
346364
options,
347365
);
348-
return result != null ? parseResult(result) : undefined;
366+
return result != null ? { ...result, parsed: parseSummarizeResult(result.content) } : undefined;
349367
}
350368

351369
async generateDraftMessage(
@@ -358,7 +376,7 @@ export class AIProviderService implements Disposable {
358376
progress?: ProgressOptions;
359377
codeSuggestion?: boolean;
360378
},
361-
): Promise<AIResult | undefined> {
379+
): Promise<AISummarizeResult | undefined> {
362380
const changes: string | undefined = await this.getChanges(changesOrRepo);
363381
if (changes == null) return undefined;
364382

@@ -392,7 +410,7 @@ export class AIProviderService implements Disposable {
392410
}),
393411
options,
394412
);
395-
return result != null ? parseResult(result) : undefined;
413+
return result != null ? { ...result, parsed: parseSummarizeResult(result.content) } : undefined;
396414
}
397415

398416
async generateStashMessage(
@@ -404,7 +422,7 @@ export class AIProviderService implements Disposable {
404422
generating?: Deferred<AIModel>;
405423
progress?: ProgressOptions;
406424
},
407-
): Promise<AIResult | undefined> {
425+
): Promise<AISummarizeResult | undefined> {
408426
const changes: string | undefined = await this.getChanges(changesOrRepo);
409427
if (changes == null) {
410428
options?.generating?.cancel();
@@ -432,14 +450,14 @@ export class AIProviderService implements Disposable {
432450
}),
433451
options,
434452
);
435-
return result != null ? parseResult(result) : undefined;
453+
return result != null ? { ...result, parsed: parseSummarizeResult(result.content) } : undefined;
436454
}
437455

438456
async generateChangelog(
439457
changes: Lazy<Promise<AIGenerateChangelogChange[]>>,
440458
source: Source,
441459
options?: { cancellation?: CancellationToken; progress?: ProgressOptions },
442-
): Promise<string | undefined> {
460+
): Promise<AIResult | undefined> {
443461
const result = await this.sendRequest(
444462
'generate-changelog',
445463
async () => ({
@@ -460,7 +478,7 @@ export class AIProviderService implements Disposable {
460478
}),
461479
options,
462480
);
463-
return result;
481+
return result != null ? { ...result } : undefined;
464482
}
465483

466484
private async sendRequest<T extends AIActionType>(
@@ -477,7 +495,7 @@ export class AIProviderService implements Disposable {
477495
generating?: Deferred<AIModel>;
478496
progress?: ProgressOptions;
479497
},
480-
): Promise<string | undefined> {
498+
): Promise<AIRequestResult | undefined> {
481499
const { confirmed, model } = await getModelAndConfirmAIProviderToS(
482500
'diff',
483501
source,
@@ -525,7 +543,7 @@ export class AIProviderService implements Disposable {
525543
? window.withProgress({ ...options.progress, title: getProgressTitle(model) }, () => promise)
526544
: promise);
527545

528-
telementry.data['output.length'] = result?.length;
546+
telementry.data['output.length'] = result?.content?.length;
529547
this.container.telemetry.sendEvent(
530548
telementry.key,
531549
{ ...telementry.data, duration: Date.now() - start },
@@ -693,7 +711,7 @@ async function getModelAndConfirmAIProviderToS(
693711
}
694712
}
695713

696-
function parseResult(result: string): AIResult {
714+
function parseSummarizeResult(result: string): NonNullable<AISummarizeResult['parsed']> {
697715
result = result.trim();
698716
let summary = result.match(/<summary>\s?([\s\S]*?)\s?(<\/summary>|$)/)?.[1]?.trim() ?? '';
699717
let body = result.match(/<body>\s?([\s\S]*?)\s?(<\/body>|$)/)?.[1]?.trim() ?? '';
@@ -720,7 +738,7 @@ function parseResult(result: string): AIResult {
720738
return { summary: summary, body: body };
721739
}
722740

723-
function splitMessageIntoSummaryAndBody(message: string): AIResult {
741+
function splitMessageIntoSummaryAndBody(message: string): NonNullable<AISummarizeResult['parsed']> {
724742
const index = message.indexOf('\n');
725743
if (index === -1) return { summary: message, body: '' };
726744

src/plus/ai/gitkrakenProvider.ts

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
import type { CancellationToken } from 'vscode';
2-
import type { Response } from '@env/fetch';
1+
import type { Disposable } from 'vscode';
32
import { fetch } from '@env/fetch';
3+
import type { Container } from '../../container';
4+
import { AuthenticationRequiredError } from '../../errors';
45
import { debug } from '../../system/decorators/log';
56
import { Logger } from '../../system/logger';
67
import { getLogScope } from '../../system/logger.scope';
8+
import { PromiseCache } from '../../system/promiseCache';
9+
import type { ServerConnection } from '../gk/serverConnection';
710
import type { AIActionType, AIModel } from './models/model';
811
import type { PromptTemplate } from './models/promptTemplates';
912
import { OpenAICompatibleProvider } from './openAICompatibleProvider';
@@ -18,16 +21,36 @@ export class GitKrakenProvider extends OpenAICompatibleProvider<typeof provider.
1821
readonly name = provider.name;
1922
protected readonly config = {};
2023

24+
private readonly _disposable: Disposable;
25+
private readonly _promptTemplates = new PromiseCache<AIActionType, PromptTemplate>({
26+
createTTL: 12 * 60 * 60 * 1000, // 12 hours
27+
expireOnError: true,
28+
});
29+
30+
constructor(container: Container, connection: ServerConnection) {
31+
super(container, connection);
32+
33+
this._disposable = this.container.subscription.onDidChange(() => this._promptTemplates.clear());
34+
}
35+
36+
override dispose(): void {
37+
this._disposable.dispose();
38+
}
39+
2140
@debug()
2241
async getModels(): Promise<readonly AIModel<typeof provider.id>[]> {
2342
const scope = getLogScope();
2443

2544
try {
26-
const rsp = await fetch(this.container.urls.getGkAIApiUrl('providers/message-prompt'), {
45+
const url = this.container.urls.getGkAIApiUrl('providers/message-prompt');
46+
const rsp = await fetch(url, {
2747
headers: await this.connection.getGkHeaders(undefined, undefined, {
2848
Accept: 'application/json',
2949
}),
3050
});
51+
if (!rsp.ok) {
52+
throw new Error(`Getting models (${url}) failed: ${rsp.status} (${rsp.statusText})`);
53+
}
3154

3255
interface ModelsResponse {
3356
data: {
@@ -43,27 +66,27 @@ export class GitKrakenProvider extends OpenAICompatibleProvider<typeof provider.
4366
}
4467

4568
const result: ModelsResponse = await rsp.json();
46-
47-
if (result.error == null) {
48-
const models: GitKrakenModel[] = result.data.map(
49-
m =>
50-
({
51-
id: m.modelId,
52-
name: m.modelName,
53-
maxTokens: { input: m.maxInputTokens, output: m.maxOutputTokens },
54-
provider: provider,
55-
default: m.preferred,
56-
temperature: null,
57-
}) satisfies GitKrakenModel,
58-
);
59-
return models;
69+
if (result.error != null) {
70+
throw new Error(`Getting models (${url}) failed: ${String(result.error)}`);
6071
}
6172

62-
debugger;
63-
Logger.error(undefined, scope, `${String(result.error)}: Unable to get models`);
73+
const models: GitKrakenModel[] = result.data.map(
74+
m =>
75+
({
76+
id: m.modelId,
77+
name: m.modelName,
78+
maxTokens: { input: m.maxInputTokens, output: m.maxOutputTokens },
79+
provider: provider,
80+
default: m.preferred,
81+
temperature: null,
82+
}) satisfies GitKrakenModel,
83+
);
84+
return models;
6485
} catch (ex) {
65-
debugger;
66-
Logger.error(ex, scope, `Unable to get models`);
86+
if (!(ex instanceof AuthenticationRequiredError)) {
87+
debugger;
88+
Logger.error(ex, scope, `Unable to get models`);
89+
}
6790
}
6891

6992
return [];
@@ -76,36 +99,43 @@ export class GitKrakenProvider extends OpenAICompatibleProvider<typeof provider.
7699
const scope = getLogScope();
77100

78101
try {
79-
const rsp = await fetch(this.container.urls.getGkAIApiUrl(`templates/message-prompt/${action}`), {
80-
headers: await this.connection.getGkHeaders(undefined, undefined, {
81-
Accept: 'application/json',
82-
}),
83-
});
102+
return await this._promptTemplates.get(action, async () => {
103+
const url = this.container.urls.getGkAIApiUrl(`templates/message-prompt/${action}`);
104+
const rsp = await fetch(url, {
105+
headers: await this.connection.getGkHeaders(undefined, undefined, {
106+
Accept: 'application/json',
107+
}),
108+
});
109+
if (!rsp.ok) {
110+
throw new Error(`Getting prompt template (${url}) failed: ${rsp.status} (${rsp.statusText})`);
111+
}
112+
113+
interface PromptResponse {
114+
data: {
115+
id: string;
116+
template: string;
117+
variables: string[];
118+
};
119+
error?: null;
120+
}
121+
122+
const result: PromptResponse = await rsp.json();
123+
if (result.error != null) {
124+
throw new Error(`Getting prompt template (${url}) failed: ${String(result.error)}`);
125+
}
84126

85-
interface PromptResponse {
86-
data: {
87-
id: string;
88-
template: string;
89-
variables: string[];
90-
};
91-
error?: null;
92-
}
93-
94-
const result: PromptResponse = await rsp.json();
95-
if (result.error == null) {
96127
return {
97128
id: result.data.id,
98129
name: getActionName(action),
99130
template: result.data.template,
100131
variables: result.data.variables,
101132
};
102-
}
103-
104-
debugger;
105-
Logger.error(undefined, scope, `${String(result.error)}: Unable to get prompt template for '${action}'`);
133+
});
106134
} catch (ex) {
107-
debugger;
108-
Logger.error(ex, scope, `Unable to get prompt template for '${action}'`);
135+
if (!(ex instanceof AuthenticationRequiredError)) {
136+
debugger;
137+
Logger.error(ex, scope, `Unable to get prompt template for '${action}'`);
138+
}
109139
}
110140

111141
return super.getPromptTemplate(action, model);
@@ -130,14 +160,4 @@ export class GitKrakenProvider extends OpenAICompatibleProvider<typeof provider.
130160
'GK-Action': action,
131161
});
132162
}
133-
134-
protected override fetchCore<TAction extends AIActionType>(
135-
action: TAction,
136-
model: AIModel<typeof provider.id>,
137-
_apiKey: string,
138-
request: object,
139-
cancellation: CancellationToken | undefined,
140-
): Promise<Response> {
141-
return super.fetchCore(action, model, _apiKey, request, cancellation);
142-
}
143163
}

0 commit comments

Comments
 (0)