Skip to content

Commit b2fb0c1

Browse files
committed
Implement robust AI-assisted rebase generation with validation and retry logic
(#4395, #4430)
1 parent 72d763a commit b2fb0c1

File tree

1 file changed

+270
-35
lines changed

1 file changed

+270
-35
lines changed

src/plus/ai/aiProviderService.ts

Lines changed: 270 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,18 @@ export class AIProviderService implements Disposable {
981981
return result === 'cancelled' ? result : result != null ? { ...result } : undefined;
982982
}
983983

984+
/**
985+
* Generates a rebase using AI to organize code changes into logical commits.
986+
*
987+
* This method includes automatic retry logic that validates the AI response and
988+
* continues the conversation if the response has issues like:
989+
* - Missing hunks that were in the original diff
990+
* - Extra hunks that weren't in the original diff
991+
* - Duplicate hunks used multiple times
992+
*
993+
* The method will retry up to 3 times, providing specific feedback to the AI
994+
* about what was wrong with the previous response.
995+
*/
984996
async generateRebase(
985997
repo: Repository,
986998
baseRef: string,
@@ -1030,6 +1042,121 @@ export class AIProviderService implements Disposable {
10301042
}
10311043
}
10321044

1045+
const rq = await this.sendRebaseRequestWithRetry(repo, baseRef, headRef, source, result, options);
1046+
1047+
if (rq === 'cancelled') return rq;
1048+
1049+
if (rq == null) return undefined;
1050+
1051+
return {
1052+
...rq,
1053+
...result,
1054+
};
1055+
}
1056+
1057+
private async sendRebaseRequestWithRetry(
1058+
repo: Repository,
1059+
baseRef: string,
1060+
headRef: string,
1061+
source: Source,
1062+
result: Mutable<AIRebaseResult>,
1063+
options?: {
1064+
cancellation?: CancellationToken;
1065+
context?: string;
1066+
generating?: Deferred<AIModel>;
1067+
progress?: ProgressOptions;
1068+
generateCommits?: boolean;
1069+
},
1070+
): Promise<AIRequestResult | 'cancelled' | undefined> {
1071+
let conversationMessages: AIChatMessage[] = [];
1072+
let attempt = 0;
1073+
const maxAttempts = 4;
1074+
1075+
// First attempt - setup diff and hunk map
1076+
const firstAttemptResult = await this.sendRebaseFirstAttempt(repo, baseRef, headRef, source, result, options);
1077+
1078+
if (firstAttemptResult === 'cancelled' || firstAttemptResult == null) {
1079+
return firstAttemptResult;
1080+
}
1081+
1082+
conversationMessages = firstAttemptResult.conversationMessages;
1083+
let rq = firstAttemptResult.response;
1084+
1085+
while (attempt < maxAttempts) {
1086+
const validationResult = this.validateRebaseResponse(rq, result.hunkMap, options);
1087+
if (validationResult.isValid) {
1088+
result.commits = validationResult.commits;
1089+
return rq;
1090+
}
1091+
1092+
Logger.warn(
1093+
undefined,
1094+
'AIProviderService',
1095+
'sendRebaseRequestWithRetry',
1096+
`Validation failed on attempt ${attempt + 1}: ${validationResult.errorMessage}`,
1097+
);
1098+
1099+
// If this was the last attempt, throw the error
1100+
if (attempt === maxAttempts - 1) {
1101+
throw new Error(validationResult.errorMessage);
1102+
}
1103+
1104+
// Prepare retry message for conversation
1105+
conversationMessages.push(
1106+
{ role: 'assistant', content: rq.content },
1107+
{ role: 'user', content: validationResult.retryPrompt },
1108+
);
1109+
1110+
attempt++;
1111+
1112+
// Send retry request
1113+
const currentAttempt = attempt;
1114+
const retryResult = await this.sendRequest(
1115+
'generate-rebase',
1116+
async () => Promise.resolve(conversationMessages),
1117+
m =>
1118+
`Generating ${options?.generateCommits ? 'commits' : 'rebase'} with ${m.name}... (attempt ${
1119+
currentAttempt + 1
1120+
})`,
1121+
source,
1122+
m => ({
1123+
key: 'ai/generate',
1124+
data: {
1125+
type: 'rebase',
1126+
'model.id': m.id,
1127+
'model.provider.id': m.provider.id,
1128+
'model.provider.name': m.provider.name,
1129+
'retry.count': currentAttempt,
1130+
},
1131+
}),
1132+
options,
1133+
);
1134+
1135+
if (retryResult === 'cancelled' || retryResult == null) {
1136+
return retryResult;
1137+
}
1138+
1139+
rq = retryResult;
1140+
}
1141+
1142+
return undefined;
1143+
}
1144+
1145+
private async sendRebaseFirstAttempt(
1146+
repo: Repository,
1147+
baseRef: string,
1148+
headRef: string,
1149+
source: Source,
1150+
result: Mutable<AIRebaseResult>,
1151+
options?: {
1152+
cancellation?: CancellationToken;
1153+
context?: string;
1154+
generating?: Deferred<AIModel>;
1155+
progress?: ProgressOptions;
1156+
generateCommits?: boolean;
1157+
},
1158+
): Promise<{ response: AIRequestResult; conversationMessages: AIChatMessage[] } | 'cancelled' | undefined> {
1159+
let storedPrompt = '';
10331160
const rq = await this.sendRequest(
10341161
'generate-rebase',
10351162
async (model, reporting, cancellation, maxInputTokens, retries) => {
@@ -1086,6 +1213,9 @@ export class AIProviderService implements Disposable {
10861213
);
10871214
if (cancellation.isCancellationRequested) throw new CancellationError();
10881215

1216+
// Store the prompt for later use in conversation messages
1217+
storedPrompt = prompt;
1218+
10891219
const messages: AIChatMessage[] = [{ role: 'user', content: prompt }];
10901220
return messages;
10911221
},
@@ -1108,47 +1238,141 @@ export class AIProviderService implements Disposable {
11081238

11091239
if (rq == null) return undefined;
11101240

1241+
return {
1242+
response: rq,
1243+
conversationMessages: [{ role: 'user', content: storedPrompt }],
1244+
};
1245+
}
1246+
1247+
private validateRebaseResponse(
1248+
rq: AIRequestResult,
1249+
inputHunkMap: { index: number; hunkHeader: string }[],
1250+
options?: {
1251+
generateCommits?: boolean;
1252+
},
1253+
):
1254+
| { isValid: false; errorMessage: string; retryPrompt: string }
1255+
| { isValid: true; commits: AIRebaseResult['commits'] } {
1256+
// if it is wrapped in markdown, we need to strip it
1257+
const content = rq.content.replace(/^\s*```json\s*/, '').replace(/\s*```$/, '');
1258+
1259+
let commits: AIRebaseResult['commits'];
11111260
try {
1112-
// if it is wrapped in markdown, we need to strip it
1113-
const content = rq.content.replace(/^\s*```json\s*/, '').replace(/\s*```$/, '');
11141261
// Parse the JSON content from the result
1115-
result.commits = JSON.parse(content) as AIRebaseResult['commits'];
1262+
commits = JSON.parse(content) as AIRebaseResult['commits'];
1263+
} catch {
1264+
const errorMessage = `Unable to parse ${options?.generateCommits ? 'commits' : 'rebase'} result`;
1265+
const retryPrompt = dedent(`
1266+
Your previous response could not be parsed as valid JSON. Please ensure your response is a valid JSON array of commits with the correct structure.
1267+
1268+
Here was your previous response:
1269+
${rq.content}
11161270
1117-
const inputHunkIndices = result.hunkMap.map(h => h.index);
1118-
const outputHunkIndices = new Set(result.commits.flatMap(c => c.hunks.map(h => h.hunk)));
1271+
Please provide a valid JSON array of commits following this structure:
1272+
[
1273+
{
1274+
"message": "commit message",
1275+
"explanation": "detailed explanation",
1276+
"hunks": [{"hunk": 1}, {"hunk": 2}]
1277+
}
1278+
]
1279+
`);
1280+
1281+
return {
1282+
isValid: false,
1283+
errorMessage: errorMessage,
1284+
retryPrompt: retryPrompt,
1285+
};
1286+
}
11191287

1120-
// Find any missing or extra hunks
1288+
// Validate the structure and hunk assignments
1289+
try {
1290+
const inputHunkIndices = inputHunkMap.map(h => h.index);
1291+
const allOutputHunks = commits.flatMap(c => c.hunks.map(h => h.hunk));
1292+
const outputHunkIndices = new Map(allOutputHunks.map((hunk, index) => [hunk, index]));
11211293
const missingHunks = inputHunkIndices.filter(i => !outputHunkIndices.has(i));
1122-
const extraHunks = [...outputHunkIndices].filter(i => !inputHunkIndices.includes(i));
1123-
if (missingHunks.length > 0 || extraHunks.length > 0) {
1124-
let hunksMessage = '';
1294+
1295+
if (missingHunks.length > 0 || allOutputHunks.length > inputHunkIndices.length) {
1296+
const errorParts: string[] = [];
1297+
const retryParts: string[] = [];
1298+
11251299
if (missingHunks.length > 0) {
11261300
const pluralize = missingHunks.length > 1 ? 's' : '';
1127-
hunksMessage += ` ${missingHunks.length} missing hunk${pluralize}.`;
1301+
errorParts.push(`${missingHunks.length} missing hunk${pluralize}`);
1302+
retryParts.push(`You missed hunk${pluralize} ${missingHunks.join(', ')} in your response`);
11281303
}
1304+
const extraHunks = [...outputHunkIndices.keys()].filter(i => !inputHunkIndices.includes(i));
11291305
if (extraHunks.length > 0) {
11301306
const pluralize = extraHunks.length > 1 ? 's' : '';
1131-
hunksMessage += ` ${extraHunks.length} extra hunk${pluralize}.`;
1307+
errorParts.push(`${extraHunks.length} extra hunk${pluralize}`);
1308+
retryParts.push(
1309+
`You included hunk${pluralize} ${extraHunks.join(', ')} which ${
1310+
extraHunks.length > 1 ? 'were' : 'was'
1311+
} not in the original diff`,
1312+
);
1313+
}
1314+
const duplicateHunks = allOutputHunks.filter((hunk, index) => outputHunkIndices.get(hunk)! !== index);
1315+
const uniqueDuplicates = [...new Set(duplicateHunks)];
1316+
if (uniqueDuplicates.length > 0) {
1317+
const pluralize = uniqueDuplicates.length > 1 ? 's' : '';
1318+
errorParts.push(`${uniqueDuplicates.length} duplicate hunk${pluralize}`);
1319+
retryParts.push(`You used hunk${pluralize} ${uniqueDuplicates.join(', ')} multiple times`);
11321320
}
11331321

1134-
throw new Error(
1135-
`Invalid response in generating ${
1136-
options?.generateCommits ? 'commits' : 'rebase'
1137-
} result.${hunksMessage} Try again or select a different AI model.`,
1138-
);
1139-
}
1140-
} catch (ex) {
1141-
debugger;
1142-
if (ex?.message?.includes('Invalid response in generating')) {
1143-
throw ex;
1322+
const errorMessage = `Invalid response in generating ${
1323+
options?.generateCommits ? 'commits' : 'rebase'
1324+
} result. ${errorParts.join(', ')}.`;
1325+
1326+
const retryPrompt = dedent(`
1327+
Your previous response had issues: ${retryParts.join(', ')}.
1328+
1329+
Please provide a corrected JSON response that:
1330+
1. Includes ALL hunks from 1 to ${Math.max(...inputHunkIndices)} exactly once
1331+
2. Does not include any hunk numbers outside this range
1332+
3. Does not use any hunk more than once
1333+
1334+
Here was your previous response:
1335+
${rq.content}
1336+
1337+
Please provide the corrected JSON array of commits:
1338+
`);
1339+
1340+
return {
1341+
isValid: false,
1342+
errorMessage: errorMessage,
1343+
retryPrompt: retryPrompt,
1344+
};
11441345
}
1145-
throw new Error(`Unable to parse ${options?.generateCommits ? 'commits' : 'rebase'} result`);
1146-
}
11471346

1148-
return {
1149-
...rq,
1150-
...result,
1151-
};
1347+
// If validation passes, return the commits
1348+
return { isValid: true, commits: commits };
1349+
} catch {
1350+
// Handle any errors during hunk validation (e.g., malformed commit structure)
1351+
const errorMessage = `Invalid commit structure in ${
1352+
options?.generateCommits ? 'commits' : 'rebase'
1353+
} result`;
1354+
const retryPrompt = dedent(`
1355+
Your previous response has an invalid commit structure. Each commit must have "message", "explanation", and "hunks" properties, where "hunks" is an array of objects with "hunk" numbers.
1356+
1357+
Here was your previous response:
1358+
${rq.content}
1359+
1360+
Please provide a valid JSON array of commits following this structure:
1361+
[
1362+
{
1363+
"message": "commit message",
1364+
"explanation": "detailed explanation",
1365+
"hunks": [{"hunk": 1}, {"hunk": 2}]
1366+
}
1367+
]
1368+
`);
1369+
1370+
return {
1371+
isValid: false,
1372+
errorMessage: errorMessage,
1373+
retryPrompt: retryPrompt,
1374+
};
1375+
}
11521376
}
11531377

11541378
private async sendRequest<T extends AIActionType>(
@@ -1654,22 +1878,32 @@ export class AIProviderService implements Disposable {
16541878
const alreadyCompleted = this.container.storage.get(`gk:promo:${userId}:ai:allAccess:dismissed`, false);
16551879
if (notificationShown || alreadyCompleted) return;
16561880

1657-
const hasAdvancedOrHigher = subscription.plan &&
1881+
const hasAdvancedOrHigher =
1882+
subscription.plan &&
16581883
(compareSubscriptionPlans(subscription.plan.actual.id, 'advanced') >= 0 ||
1659-
compareSubscriptionPlans(subscription.plan.effective.id, 'advanced') >= 0);
1884+
compareSubscriptionPlans(subscription.plan.effective.id, 'advanced') >= 0);
16601885

16611886
let body = 'All Access Week - now until July 11th!';
1662-
const detail = hasAdvancedOrHigher ? 'Opt in now to get unlimited GitKraken AI until July 11th!' : 'Opt in now to try all Advanced GitLens features with unlimited GitKraken AI for FREE until July 11th!';
1887+
const detail = hasAdvancedOrHigher
1888+
? 'Opt in now to get unlimited GitKraken AI until July 11th!'
1889+
: 'Opt in now to try all Advanced GitLens features with unlimited GitKraken AI for FREE until July 11th!';
16631890

16641891
if (!usingGkProvider) {
16651892
body += ` ${detail}`;
16661893
}
16671894

1668-
const optInButton: MessageItem = usingGkProvider ? { title: 'Opt in for Unlimited AI' } : { title: 'Opt in and Switch to GitKraken AI' };
1895+
const optInButton: MessageItem = usingGkProvider
1896+
? { title: 'Opt in for Unlimited AI' }
1897+
: { title: 'Opt in and Switch to GitKraken AI' };
16691898
const dismissButton: MessageItem = { title: 'No, Thanks', isCloseAffordance: true };
16701899

16711900
// Show the notification
1672-
const result = await window.showInformationMessage(body, { modal: usingGkProvider, detail: detail }, optInButton, dismissButton);
1901+
const result = await window.showInformationMessage(
1902+
body,
1903+
{ modal: usingGkProvider, detail: detail },
1904+
optInButton,
1905+
dismissButton,
1906+
);
16731907

16741908
// Mark notification as shown regardless of user action
16751909
void this.container.storage.store(`gk:promo:${userId}:ai:allAccess:notified`, true);
@@ -1692,7 +1926,10 @@ export class AIProviderService implements Disposable {
16921926
await configuration.updateEffective('ai.model', 'gitkraken');
16931927
await configuration.updateEffective(`ai.gitkraken.model`, defaultModel.id);
16941928
} else {
1695-
await configuration.updateEffective('ai.model', `gitkraken:${defaultModel.id}` as SupportedAIModels);
1929+
await configuration.updateEffective(
1930+
'ai.model',
1931+
`gitkraken:${defaultModel.id}` as SupportedAIModels,
1932+
);
16961933
}
16971934

16981935
this._onDidChangeModel.fire({ model: defaultModel });
@@ -1701,8 +1938,6 @@ export class AIProviderService implements Disposable {
17011938
}
17021939
}
17031940

1704-
1705-
17061941
async function showConfirmAIProviderToS(storage: Storage): Promise<boolean> {
17071942
const confirmed = storage.get(`confirm:ai:tos`, false) || storage.getWorkspace(`confirm:ai:tos`, false);
17081943
if (confirmed) return true;

0 commit comments

Comments
 (0)