Skip to content

Commit f05b901

Browse files
committed
Implement robust AI-assisted rebase generation with validation and retry logic
(#4395, #4430)
1 parent 5992945 commit f05b901

File tree

1 file changed

+251
-27
lines changed

1 file changed

+251
-27
lines changed

src/plus/ai/aiProviderService.ts

Lines changed: 251 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,18 @@ export class AIProviderService implements Disposable {
937937
return result === 'cancelled' ? result : result != null ? { ...result } : undefined;
938938
}
939939

940+
/**
941+
* Generates a rebase using AI to organize code changes into logical commits.
942+
*
943+
* This method includes automatic retry logic that validates the AI response and
944+
* continues the conversation if the response has issues like:
945+
* - Missing hunks that were in the original diff
946+
* - Extra hunks that weren't in the original diff
947+
* - Duplicate hunks used multiple times
948+
*
949+
* The method will retry up to 3 times, providing specific feedback to the AI
950+
* about what was wrong with the previous response.
951+
*/
940952
async generateRebase(
941953
repo: Repository,
942954
baseRef: string,
@@ -986,6 +998,121 @@ export class AIProviderService implements Disposable {
986998
}
987999
}
9881000

1001+
const rq = await this.sendRebaseRequestWithRetry(repo, baseRef, headRef, source, result, options);
1002+
1003+
if (rq === 'cancelled') return rq;
1004+
1005+
if (rq == null) return undefined;
1006+
1007+
return {
1008+
...rq,
1009+
...result,
1010+
};
1011+
}
1012+
1013+
private async sendRebaseRequestWithRetry(
1014+
repo: Repository,
1015+
baseRef: string,
1016+
headRef: string,
1017+
source: Source,
1018+
result: Mutable<AIRebaseResult>,
1019+
options?: {
1020+
cancellation?: CancellationToken;
1021+
context?: string;
1022+
generating?: Deferred<AIModel>;
1023+
progress?: ProgressOptions;
1024+
generateCommits?: boolean;
1025+
},
1026+
): Promise<AIRequestResult | 'cancelled' | undefined> {
1027+
let conversationMessages: AIChatMessage[] = [];
1028+
let attempt = 0;
1029+
const maxAttempts = 4;
1030+
1031+
// First attempt - setup diff and hunk map
1032+
const firstAttemptResult = await this.sendRebaseFirstAttempt(repo, baseRef, headRef, source, result, options);
1033+
1034+
if (firstAttemptResult === 'cancelled' || firstAttemptResult == null) {
1035+
return firstAttemptResult;
1036+
}
1037+
1038+
conversationMessages = firstAttemptResult.conversationMessages;
1039+
let rq = firstAttemptResult.response;
1040+
1041+
while (attempt < maxAttempts) {
1042+
const validationResult = this.validateRebaseResponse(rq, result.hunkMap, options);
1043+
if (validationResult.isValid) {
1044+
result.commits = validationResult.commits;
1045+
return rq;
1046+
}
1047+
1048+
Logger.warn(
1049+
undefined,
1050+
'AIProviderService',
1051+
'sendRebaseRequestWithRetry',
1052+
`Validation failed on attempt ${attempt + 1}: ${validationResult.errorMessage}`,
1053+
);
1054+
1055+
// If this was the last attempt, throw the error
1056+
if (attempt === maxAttempts - 1) {
1057+
throw new Error(validationResult.errorMessage);
1058+
}
1059+
1060+
// Prepare retry message for conversation
1061+
conversationMessages.push(
1062+
{ role: 'assistant', content: rq.content },
1063+
{ role: 'user', content: validationResult.retryPrompt },
1064+
);
1065+
1066+
attempt++;
1067+
1068+
// Send retry request
1069+
const currentAttempt = attempt;
1070+
const retryResult = await this.sendRequest(
1071+
'generate-rebase',
1072+
async () => Promise.resolve(conversationMessages),
1073+
m =>
1074+
`Generating ${options?.generateCommits ? 'commits' : 'rebase'} with ${m.name}... (attempt ${
1075+
currentAttempt + 1
1076+
})`,
1077+
source,
1078+
m => ({
1079+
key: 'ai/generate',
1080+
data: {
1081+
type: 'rebase',
1082+
'model.id': m.id,
1083+
'model.provider.id': m.provider.id,
1084+
'model.provider.name': m.provider.name,
1085+
'retry.count': currentAttempt,
1086+
},
1087+
}),
1088+
options,
1089+
);
1090+
1091+
if (retryResult === 'cancelled' || retryResult == null) {
1092+
return retryResult;
1093+
}
1094+
1095+
rq = retryResult;
1096+
}
1097+
1098+
return undefined;
1099+
}
1100+
1101+
private async sendRebaseFirstAttempt(
1102+
repo: Repository,
1103+
baseRef: string,
1104+
headRef: string,
1105+
source: Source,
1106+
result: Mutable<AIRebaseResult>,
1107+
options?: {
1108+
cancellation?: CancellationToken;
1109+
context?: string;
1110+
generating?: Deferred<AIModel>;
1111+
progress?: ProgressOptions;
1112+
generateCommits?: boolean;
1113+
},
1114+
): Promise<{ response: AIRequestResult; conversationMessages: AIChatMessage[] } | 'cancelled' | undefined> {
1115+
let storedPrompt = '';
9891116
const rq = await this.sendRequest(
9901117
'generate-rebase',
9911118
async (model, reporting, cancellation, maxInputTokens, retries) => {
@@ -1042,6 +1169,9 @@ export class AIProviderService implements Disposable {
10421169
);
10431170
if (cancellation.isCancellationRequested) throw new CancellationError();
10441171

1172+
// Store the prompt for later use in conversation messages
1173+
storedPrompt = prompt;
1174+
10451175
const messages: AIChatMessage[] = [{ role: 'user', content: prompt }];
10461176
return messages;
10471177
},
@@ -1064,47 +1194,141 @@ export class AIProviderService implements Disposable {
10641194

10651195
if (rq == null) return undefined;
10661196

1197+
return {
1198+
response: rq,
1199+
conversationMessages: [{ role: 'user', content: storedPrompt }],
1200+
};
1201+
}
1202+
1203+
private validateRebaseResponse(
1204+
rq: AIRequestResult,
1205+
inputHunkMap: { index: number; hunkHeader: string }[],
1206+
options?: {
1207+
generateCommits?: boolean;
1208+
},
1209+
):
1210+
| { isValid: false; errorMessage: string; retryPrompt: string }
1211+
| { isValid: true; commits: AIRebaseResult['commits'] } {
1212+
// if it is wrapped in markdown, we need to strip it
1213+
const content = rq.content.replace(/^\s*```json\s*/, '').replace(/\s*```$/, '');
1214+
1215+
let commits: AIRebaseResult['commits'];
10671216
try {
1068-
// if it is wrapped in markdown, we need to strip it
1069-
const content = rq.content.replace(/^\s*```json\s*/, '').replace(/\s*```$/, '');
10701217
// Parse the JSON content from the result
1071-
result.commits = JSON.parse(content) as AIRebaseResult['commits'];
1218+
commits = JSON.parse(content) as AIRebaseResult['commits'];
1219+
} catch {
1220+
const errorMessage = `Unable to parse ${options?.generateCommits ? 'commits' : 'rebase'} result`;
1221+
const retryPrompt = dedent(`
1222+
Your previous response could not be parsed as valid JSON. Please ensure your response is a valid JSON array of commits with the correct structure.
1223+
1224+
Here was your previous response:
1225+
${rq.content}
10721226
1073-
const inputHunkIndices = result.hunkMap.map(h => h.index);
1074-
const outputHunkIndices = new Set(result.commits.flatMap(c => c.hunks.map(h => h.hunk)));
1227+
Please provide a valid JSON array of commits following this structure:
1228+
[
1229+
{
1230+
"message": "commit message",
1231+
"explanation": "detailed explanation",
1232+
"hunks": [{"hunk": 1}, {"hunk": 2}]
1233+
}
1234+
]
1235+
`);
1236+
1237+
return {
1238+
isValid: false,
1239+
errorMessage: errorMessage,
1240+
retryPrompt: retryPrompt,
1241+
};
1242+
}
10751243

1076-
// Find any missing or extra hunks
1244+
// Validate the structure and hunk assignments
1245+
try {
1246+
const inputHunkIndices = inputHunkMap.map(h => h.index);
1247+
const allOutputHunks = commits.flatMap(c => c.hunks.map(h => h.hunk));
1248+
const outputHunkIndices = new Map(allOutputHunks.map((hunk, index) => [hunk, index]));
10771249
const missingHunks = inputHunkIndices.filter(i => !outputHunkIndices.has(i));
1078-
const extraHunks = [...outputHunkIndices].filter(i => !inputHunkIndices.includes(i));
1079-
if (missingHunks.length > 0 || extraHunks.length > 0) {
1080-
let hunksMessage = '';
1250+
1251+
if (missingHunks.length > 0 || allOutputHunks.length > inputHunkIndices.length) {
1252+
const errorParts: string[] = [];
1253+
const retryParts: string[] = [];
1254+
10811255
if (missingHunks.length > 0) {
10821256
const pluralize = missingHunks.length > 1 ? 's' : '';
1083-
hunksMessage += ` ${missingHunks.length} missing hunk${pluralize}.`;
1257+
errorParts.push(`${missingHunks.length} missing hunk${pluralize}`);
1258+
retryParts.push(`You missed hunk${pluralize} ${missingHunks.join(', ')} in your response`);
10841259
}
1260+
const extraHunks = [...outputHunkIndices.keys()].filter(i => !inputHunkIndices.includes(i));
10851261
if (extraHunks.length > 0) {
10861262
const pluralize = extraHunks.length > 1 ? 's' : '';
1087-
hunksMessage += ` ${extraHunks.length} extra hunk${pluralize}.`;
1263+
errorParts.push(`${extraHunks.length} extra hunk${pluralize}`);
1264+
retryParts.push(
1265+
`You included hunk${pluralize} ${extraHunks.join(', ')} which ${
1266+
extraHunks.length > 1 ? 'were' : 'was'
1267+
} not in the original diff`,
1268+
);
1269+
}
1270+
const duplicateHunks = allOutputHunks.filter((hunk, index) => outputHunkIndices.get(hunk)! !== index);
1271+
const uniqueDuplicates = [...new Set(duplicateHunks)];
1272+
if (uniqueDuplicates.length > 0) {
1273+
const pluralize = uniqueDuplicates.length > 1 ? 's' : '';
1274+
errorParts.push(`${uniqueDuplicates.length} duplicate hunk${pluralize}`);
1275+
retryParts.push(`You used hunk${pluralize} ${uniqueDuplicates.join(', ')} multiple times`);
10881276
}
10891277

1090-
throw new Error(
1091-
`Invalid response in generating ${
1092-
options?.generateCommits ? 'commits' : 'rebase'
1093-
} result.${hunksMessage} Try again or select a different AI model.`,
1094-
);
1095-
}
1096-
} catch (ex) {
1097-
debugger;
1098-
if (ex?.message?.includes('Invalid response in generating')) {
1099-
throw ex;
1278+
const errorMessage = `Invalid response in generating ${
1279+
options?.generateCommits ? 'commits' : 'rebase'
1280+
} result. ${errorParts.join(', ')}.`;
1281+
1282+
const retryPrompt = dedent(`
1283+
Your previous response had issues: ${retryParts.join(', ')}.
1284+
1285+
Please provide a corrected JSON response that:
1286+
1. Includes ALL hunks from 1 to ${Math.max(...inputHunkIndices)} exactly once
1287+
2. Does not include any hunk numbers outside this range
1288+
3. Does not use any hunk more than once
1289+
1290+
Here was your previous response:
1291+
${rq.content}
1292+
1293+
Please provide the corrected JSON array of commits:
1294+
`);
1295+
1296+
return {
1297+
isValid: false,
1298+
errorMessage: errorMessage,
1299+
retryPrompt: retryPrompt,
1300+
};
11001301
}
1101-
throw new Error(`Unable to parse ${options?.generateCommits ? 'commits' : 'rebase'} result`);
1102-
}
11031302

1104-
return {
1105-
...rq,
1106-
...result,
1107-
};
1303+
// If validation passes, return the commits
1304+
return { isValid: true, commits: commits };
1305+
} catch {
1306+
// Handle any errors during hunk validation (e.g., malformed commit structure)
1307+
const errorMessage = `Invalid commit structure in ${
1308+
options?.generateCommits ? 'commits' : 'rebase'
1309+
} result`;
1310+
const retryPrompt = dedent(`
1311+
Your previous response has an invalid commit structure. Each commit must have "message", "explanation", and "hunks" properties, where "hunks" is an array of objects with "hunk" numbers.
1312+
1313+
Here was your previous response:
1314+
${rq.content}
1315+
1316+
Please provide a valid JSON array of commits following this structure:
1317+
[
1318+
{
1319+
"message": "commit message",
1320+
"explanation": "detailed explanation",
1321+
"hunks": [{"hunk": 1}, {"hunk": 2}]
1322+
}
1323+
]
1324+
`);
1325+
1326+
return {
1327+
isValid: false,
1328+
errorMessage: errorMessage,
1329+
retryPrompt: retryPrompt,
1330+
};
1331+
}
11081332
}
11091333

11101334
private async sendRequest<T extends AIActionType>(

0 commit comments

Comments
 (0)