Skip to content

Commit 377ad9c

Browse files
committed
Improves prompts & template management
- Adds template versioning support - Adds conditional body rendering to prevent empty newlines - Improves prompt parsing for more reliable results - Improves error handling for missing templates
1 parent 9bfeeac commit 377ad9c

File tree

6 files changed

+200
-197
lines changed

6 files changed

+200
-197
lines changed

src/commands/generateCommitMessage.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ export class GenerateCommitMessageCommand extends ActiveEditorCommand {
7272
if (result == null) return;
7373

7474
void executeCoreCommand('workbench.view.scm');
75-
scmRepo.inputBox.value = `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}\n\n${
76-
result.parsed.body
75+
scmRepo.inputBox.value = `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}${
76+
result.parsed.body ? `\n\n${result.parsed.body}` : ''
7777
}`;
7878
} catch (ex) {
7979
Logger.error(ex, 'GenerateCommitMessageCommand');

src/env/node/git/commitMessageProvider.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class AICommitMessageProvider implements CommitMessageProvider, Disposable {
6363
);
6464

6565
if (result == null) return;
66-
return `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}\n\n${result.parsed.body}`;
66+
return `${currentMessage ? `${currentMessage}\n\n` : ''}${result.parsed.summary}${
67+
result.parsed.body ? `\n\n${result.parsed.body}` : ''
68+
}`;
6769
} catch (ex) {
6870
Logger.error(ex, scope);
6971

src/plus/ai/aiProviderService.ts

Lines changed: 81 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,12 @@ import type {
5757
AIProviderDescriptorWithConfiguration,
5858
AIProviderDescriptorWithType,
5959
} from './models/model';
60-
import type { PromptTemplate, PromptTemplateContext, PromptTemplateType } from './models/promptTemplates';
60+
import type {
61+
PromptTemplate,
62+
PromptTemplateContext,
63+
PromptTemplateId,
64+
PromptTemplateType,
65+
} from './models/promptTemplates';
6166
import type { AIChatMessage, AIProvider, AIRequestResult } from './models/provider';
6267
import { getLocalPromptTemplate, resolvePrompt } from './utils/-webview/prompt.utils';
6368

@@ -201,9 +206,8 @@ export class AIProviderService implements Disposable {
201206

202207
private readonly _disposable: Disposable;
203208
private _model: AIModel | undefined;
204-
private readonly _promptTemplates = new PromiseCache<PromptTemplateType, PromptTemplate>({
209+
private readonly _promptTemplates = new PromiseCache<PromptTemplateId, PromptTemplate | undefined>({
205210
createTTL: 12 * 60 * 60 * 1000, // 12 hours
206-
expireOnError: true,
207211
});
208212
private _provider: AIProvider | undefined;
209213
private _providerDisposable: Disposable | undefined;
@@ -1114,71 +1118,81 @@ export class AIProviderService implements Disposable {
11141118
}
11151119

11161120
private async getPrompt<T extends PromptTemplateType>(
1117-
template: T,
1121+
templateType: T,
11181122
model: AIModel,
11191123
context: PromptTemplateContext<T>,
11201124
maxInputTokens: number,
11211125
retries: number,
11221126
reporting: TelemetryEvents['ai/generate' | 'ai/explain'],
11231127
): Promise<{ prompt: string; truncated: boolean }> {
1124-
const promptTemplate = await this.getPromptTemplate(template, model);
1128+
const promptTemplate = await this.getPromptTemplate(templateType, model);
11251129
if (promptTemplate == null) {
11261130
debugger;
1127-
throw new Error(`No prompt template found for ${template}`);
1131+
throw new Error(`No prompt template found for ${templateType}`);
11281132
}
11291133

11301134
const result = await resolvePrompt(model, promptTemplate, context, maxInputTokens, retries, reporting);
11311135
return result;
11321136
}
11331137

11341138
private async getPromptTemplate<T extends PromptTemplateType>(
1135-
template: T,
1139+
templateType: T,
11361140
model: AIModel,
11371141
): Promise<PromptTemplate | undefined> {
1138-
if ((await this.container.subscription.getSubscription()).account) {
1139-
const scope = getLogScope();
1142+
const scope = getLogScope();
1143+
1144+
const template = getLocalPromptTemplate(templateType, model);
1145+
const templateId = template?.id ?? templateType;
1146+
1147+
return this._promptTemplates.get(templateId, async cancellable => {
1148+
if (!(await this.container.subscription.getSubscription()).account) {
1149+
return template;
1150+
}
11401151

11411152
try {
1142-
return await this._promptTemplates.get(template, async () => {
1143-
const url = this.container.urls.getGkAIApiUrl(`templates/message-prompt/${template}`);
1144-
const rsp = await fetch(url, {
1145-
headers: await this.connection.getGkHeaders(undefined, undefined, {
1146-
Accept: 'application/json',
1147-
}),
1148-
});
1149-
if (!rsp.ok) {
1150-
throw new Error(`Getting prompt template (${url}) failed: ${rsp.status} (${rsp.statusText})`);
1153+
const url = this.container.urls.getGkAIApiUrl(`templates/message-prompt/${templateId}`);
1154+
const rsp = await fetch(url, {
1155+
headers: await this.connection.getGkHeaders(undefined, undefined, { Accept: 'application/json' }),
1156+
});
1157+
if (!rsp.ok) {
1158+
if (rsp.status === 404) {
1159+
Logger.warn(
1160+
scope,
1161+
`${rsp.status} (${rsp.statusText}): Failed to get prompt template '${templateId}' (${url})`,
1162+
);
1163+
return template;
11511164
}
11521165

1153-
interface PromptResponse {
1154-
data: {
1155-
id: string;
1156-
template: string;
1157-
variables: string[];
1158-
};
1159-
error?: null;
1160-
}
1166+
if (rsp.status === 401) throw new AuthenticationRequiredError();
1167+
throw new Error(
1168+
`${rsp.status} (${rsp.statusText}): Failed to get prompt template '${templateId}' (${url})`,
1169+
);
1170+
}
11611171

1162-
const result: PromptResponse = (await rsp.json()) as PromptResponse;
1163-
if (result.error != null) {
1164-
throw new Error(`Getting prompt template (${url}) failed: ${String(result.error)}`);
1165-
}
1172+
interface PromptResponse {
1173+
data: { id: string; template: string; variables: string[] };
1174+
error?: null;
1175+
}
11661176

1167-
return {
1168-
id: result.data.id,
1169-
template: result.data.template,
1170-
variables: result.data.variables,
1171-
};
1172-
});
1177+
const result: PromptResponse = (await rsp.json()) as PromptResponse;
1178+
if (result.error != null) {
1179+
throw new Error(`Failed to get prompt template '${templateId}' (${url}). ${String(result.error)}`);
1180+
}
1181+
1182+
return {
1183+
id: result.data.id as PromptTemplateId<T>,
1184+
template: result.data.template,
1185+
variables: result.data.variables as (keyof PromptTemplateContext<T>)[],
1186+
} satisfies PromptTemplate<T>;
11731187
} catch (ex) {
1188+
cancellable.cancel();
11741189
if (!(ex instanceof AuthenticationRequiredError)) {
11751190
debugger;
1176-
Logger.error(ex, scope, `Unable to get prompt template for '${template}'`);
1191+
Logger.error(ex, scope, String(ex));
11771192
}
1193+
return template;
11781194
}
1179-
}
1180-
1181-
return getLocalPromptTemplate(template, model);
1195+
});
11821196
}
11831197

11841198
async reset(all?: boolean): Promise<void> {
@@ -1292,32 +1306,42 @@ async function showConfirmAIProviderToS(storage: Storage): Promise<boolean> {
12921306

12931307
function parseSummarizeResult(result: string): NonNullable<AISummarizeResult['parsed']> {
12941308
result = result.trim();
1295-
let summary = result.match(/<summary>\s?([\s\S]*?)\s?(<\/summary>|$)/)?.[1]?.trim() ?? '';
1296-
let body = result.match(/<body>\s?([\s\S]*?)\s?(<\/body>|$)/)?.[1]?.trim() ?? '';
1309+
const summary = result.match(/<summary>([\s\S]*?)(?:<\/summary>|$)/)?.[1]?.trim() ?? undefined;
1310+
if (summary != null) {
1311+
result = result.replace(/<summary>[\s\S]*?(?:<\/summary>|$)/, '').trim();
1312+
}
1313+
1314+
let body = result.match(/<body>([\s\S]*?)(?:<\/body>|$)/)?.[1]?.trim() ?? undefined;
1315+
if (body != null) {
1316+
result = result.replace(/<body>[\s\S]*?(?:<\/body>|$)/, '').trim();
1317+
}
1318+
1319+
// Check for self-closing body tag
1320+
if (body == null && result.includes('<body/>')) {
1321+
body = '';
1322+
}
1323+
1324+
// If both tags are present, return them
1325+
if (summary != null && body != null) return { summary: summary, body: body };
12971326

12981327
// If both tags are missing, split the result
1299-
if (!summary && !body) {
1300-
return splitMessageIntoSummaryAndBody(result);
1328+
if (summary == null && body == null) return splitMessageIntoSummaryAndBody(result);
1329+
1330+
// If only summary tag is present, use any remaining text as the body
1331+
if (summary && body == null) {
1332+
return result ? { summary: summary, body: result } : splitMessageIntoSummaryAndBody(summary);
13011333
}
13021334

1303-
if (summary && !body) {
1304-
// If only summary tag is present, use the remaining text as the body
1305-
body = result.replace(/<summary>[\s\S]*?<\/summary>/, '')?.trim() ?? '';
1306-
if (!body) {
1307-
return splitMessageIntoSummaryAndBody(summary);
1308-
}
1309-
} else if (!summary && body) {
1310-
// If only body tag is present, use the remaining text as the summary
1311-
summary = result.replace(/<body>[\s\S]*?<\/body>/, '').trim() ?? '';
1312-
if (!summary) {
1313-
return splitMessageIntoSummaryAndBody(body);
1314-
}
1335+
// If only body tag is present, use the remaining text as the summary
1336+
if (summary == null && body) {
1337+
return result ? { summary: result, body: body } : splitMessageIntoSummaryAndBody(body);
13151338
}
13161339

1317-
return { summary: summary, body: body };
1340+
return { summary: summary ?? '', body: body ?? '' };
13181341
}
13191342

13201343
function splitMessageIntoSummaryAndBody(message: string): NonNullable<AISummarizeResult['parsed']> {
1344+
message = message.replace(/```([\s\S]*?)```/, '$1').trim();
13211345
const index = message.indexOf('\n');
13221346
if (index === -1) return { summary: message, body: '' };
13231347

src/plus/ai/models/promptTemplates.ts

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
export interface PromptTemplate {
2-
readonly id?: string;
1+
export interface PromptTemplate<T extends PromptTemplateType = PromptTemplateType> {
2+
readonly id: PromptTemplateId<T>;
33
readonly template: string;
4-
readonly variables: string[];
4+
readonly variables: (keyof PromptTemplateContext<T>)[];
55
}
66

77
interface ChangelogPromptTemplateContext {
@@ -47,16 +47,21 @@ export type PromptTemplateType =
4747
| `generate-create-${'cloudPatch' | 'codeSuggestion' | 'pullRequest'}`
4848
| 'explain-changes';
4949

50+
type PromptTemplateVersions = '' | '_v2';
51+
52+
export type PromptTemplateId<T extends PromptTemplateType = PromptTemplateType> = `${T}${PromptTemplateVersions}`;
53+
54+
// prettier-ignore
5055
export type PromptTemplateContext<T extends PromptTemplateType> = T extends 'generate-commitMessage'
5156
? CommitMessagePromptTemplateContext
5257
: T extends 'generate-stashMessage'
53-
? StashMessagePromptTemplateContext
54-
: T extends 'generate-create-cloudPatch' | 'generate-create-codeSuggestion'
55-
? CreateDraftPromptTemplateContext
56-
: T extends 'generate-create-pullRequest'
57-
? CreatePullRequestPromptTemplateContext
58-
: T extends 'generate-changelog'
59-
? ChangelogPromptTemplateContext
60-
: T extends 'explain-changes'
61-
? ExplainChangesPromptTemplateContext
62-
: never;
58+
? StashMessagePromptTemplateContext
59+
: T extends 'generate-create-cloudPatch' | 'generate-create-codeSuggestion'
60+
? CreateDraftPromptTemplateContext
61+
: T extends 'generate-create-pullRequest'
62+
? CreatePullRequestPromptTemplateContext
63+
: T extends 'generate-changelog'
64+
? ChangelogPromptTemplateContext
65+
: T extends 'explain-changes'
66+
? ExplainChangesPromptTemplateContext
67+
: never;

0 commit comments

Comments
 (0)