@@ -937,6 +937,18 @@ export class AIProviderService implements Disposable {
937937 return result === 'cancelled' ? result : result != null ? { ...result } : undefined ;
938938 }
939939
940+ /**
941+ * Generates a rebase using AI to organize code changes into logical commits.
942+ *
943+ * This method includes automatic retry logic that validates the AI response and
944+ * continues the conversation if the response has issues like:
945+ * - Missing hunks that were in the original diff
946+ * - Extra hunks that weren't in the original diff
947+ * - Duplicate hunks used multiple times
948+ *
949+ * The method will retry up to 3 times, providing specific feedback to the AI
950+ * about what was wrong with the previous response.
951+ */
940952 async generateRebase (
941953 repo : Repository ,
942954 baseRef : string ,
@@ -986,6 +998,121 @@ export class AIProviderService implements Disposable {
986998 }
987999 }
9881000
1001+ const rq = await this . sendRebaseRequestWithRetry ( repo , baseRef , headRef , source , result , options ) ;
1002+
1003+ if ( rq === 'cancelled' ) return rq ;
1004+
1005+ if ( rq == null ) return undefined ;
1006+
1007+ return {
1008+ ...rq ,
1009+ ...result ,
1010+ } ;
1011+ }
1012+
1013+ private async sendRebaseRequestWithRetry (
1014+ repo : Repository ,
1015+ baseRef : string ,
1016+ headRef : string ,
1017+ source : Source ,
1018+ result : Mutable < AIRebaseResult > ,
1019+ options ?: {
1020+ cancellation ?: CancellationToken ;
1021+ context ?: string ;
1022+ generating ?: Deferred < AIModel > ;
1023+ progress ?: ProgressOptions ;
1024+ generateCommits ?: boolean ;
1025+ } ,
1026+ ) : Promise < AIRequestResult | 'cancelled' | undefined > {
1027+ let conversationMessages : AIChatMessage [ ] = [ ] ;
1028+ let attempt = 0 ;
1029+ const maxAttempts = 4 ;
1030+
1031+ // First attempt - setup diff and hunk map
1032+ const firstAttemptResult = await this . sendRebaseFirstAttempt ( repo , baseRef , headRef , source , result , options ) ;
1033+
1034+ if ( firstAttemptResult === 'cancelled' || firstAttemptResult == null ) {
1035+ return firstAttemptResult ;
1036+ }
1037+
1038+ conversationMessages = firstAttemptResult . conversationMessages ;
1039+ let rq = firstAttemptResult . response ;
1040+
1041+ while ( attempt < maxAttempts ) {
1042+ const validationResult = this . validateRebaseResponse ( rq , result . hunkMap , options ) ;
1043+ if ( validationResult . isValid ) {
1044+ result . commits = validationResult . commits ;
1045+ return rq ;
1046+ }
1047+
1048+ Logger . warn (
1049+ undefined ,
1050+ 'AIProviderService' ,
1051+ 'sendRebaseRequestWithRetry' ,
1052+ `Validation failed on attempt ${ attempt + 1 } : ${ validationResult . errorMessage } ` ,
1053+ ) ;
1054+
1055+ // If this was the last attempt, throw the error
1056+ if ( attempt === maxAttempts - 1 ) {
1057+ throw new Error ( validationResult . errorMessage ) ;
1058+ }
1059+
1060+ // Prepare retry message for conversation
1061+ conversationMessages . push (
1062+ { role : 'assistant' , content : rq . content } ,
1063+ { role : 'user' , content : validationResult . retryPrompt } ,
1064+ ) ;
1065+
1066+ attempt ++ ;
1067+
1068+ // Send retry request
1069+ const currentAttempt = attempt ;
1070+ const retryResult = await this . sendRequest (
1071+ 'generate-rebase' ,
1072+ async ( ) => Promise . resolve ( conversationMessages ) ,
1073+ m =>
1074+ `Generating ${ options ?. generateCommits ? 'commits' : 'rebase' } with ${ m . name } ... (attempt ${
1075+ currentAttempt + 1
1076+ } )`,
1077+ source ,
1078+ m => ( {
1079+ key : 'ai/generate' ,
1080+ data : {
1081+ type : 'rebase' ,
1082+ 'model.id' : m . id ,
1083+ 'model.provider.id' : m . provider . id ,
1084+ 'model.provider.name' : m . provider . name ,
1085+ 'retry.count' : currentAttempt ,
1086+ } ,
1087+ } ) ,
1088+ options ,
1089+ ) ;
1090+
1091+ if ( retryResult === 'cancelled' || retryResult == null ) {
1092+ return retryResult ;
1093+ }
1094+
1095+ rq = retryResult ;
1096+ }
1097+
1098+ return undefined ;
1099+ }
1100+
1101+ private async sendRebaseFirstAttempt (
1102+ repo : Repository ,
1103+ baseRef : string ,
1104+ headRef : string ,
1105+ source : Source ,
1106+ result : Mutable < AIRebaseResult > ,
1107+ options ?: {
1108+ cancellation ?: CancellationToken ;
1109+ context ?: string ;
1110+ generating ?: Deferred < AIModel > ;
1111+ progress ?: ProgressOptions ;
1112+ generateCommits ?: boolean ;
1113+ } ,
1114+ ) : Promise < { response : AIRequestResult ; conversationMessages : AIChatMessage [ ] } | 'cancelled' | undefined > {
1115+ let storedPrompt = '' ;
9891116 const rq = await this . sendRequest (
9901117 'generate-rebase' ,
9911118 async ( model , reporting , cancellation , maxInputTokens , retries ) => {
@@ -1042,6 +1169,9 @@ export class AIProviderService implements Disposable {
10421169 ) ;
10431170 if ( cancellation . isCancellationRequested ) throw new CancellationError ( ) ;
10441171
1172+ // Store the prompt for later use in conversation messages
1173+ storedPrompt = prompt ;
1174+
10451175 const messages : AIChatMessage [ ] = [ { role : 'user' , content : prompt } ] ;
10461176 return messages ;
10471177 } ,
@@ -1064,47 +1194,141 @@ export class AIProviderService implements Disposable {
10641194
10651195 if ( rq == null ) return undefined ;
10661196
1197+ return {
1198+ response : rq ,
1199+ conversationMessages : [ { role : 'user' , content : storedPrompt } ] ,
1200+ } ;
1201+ }
1202+
1203+ private validateRebaseResponse (
1204+ rq : AIRequestResult ,
1205+ inputHunkMap : { index : number ; hunkHeader : string } [ ] ,
1206+ options ?: {
1207+ generateCommits ?: boolean ;
1208+ } ,
1209+ ) :
1210+ | { isValid : false ; errorMessage : string ; retryPrompt : string }
1211+ | { isValid : true ; commits : AIRebaseResult [ 'commits' ] } {
1212+ // if it is wrapped in markdown, we need to strip it
1213+ const content = rq . content . replace ( / ^ \s * ` ` ` j s o n \s * / , '' ) . replace ( / \s * ` ` ` $ / , '' ) ;
1214+
1215+ let commits : AIRebaseResult [ 'commits' ] ;
10671216 try {
1068- // if it is wrapped in markdown, we need to strip it
1069- const content = rq . content . replace ( / ^ \s * ` ` ` j s o n \s * / , '' ) . replace ( / \s * ` ` ` $ / , '' ) ;
10701217 // Parse the JSON content from the result
1071- result . commits = JSON . parse ( content ) as AIRebaseResult [ 'commits' ] ;
1218+ commits = JSON . parse ( content ) as AIRebaseResult [ 'commits' ] ;
1219+ } catch {
1220+ const errorMessage = `Unable to parse ${ options ?. generateCommits ? 'commits' : 'rebase' } result` ;
1221+ const retryPrompt = dedent ( `
1222+ Your previous response could not be parsed as valid JSON. Please ensure your response is a valid JSON array of commits with the correct structure.
1223+
1224+ Here was your previous response:
1225+ ${ rq . content }
10721226
1073- const inputHunkIndices = result . hunkMap . map ( h => h . index ) ;
1074- const outputHunkIndices = new Set ( result . commits . flatMap ( c => c . hunks . map ( h => h . hunk ) ) ) ;
1227+ Please provide a valid JSON array of commits following this structure:
1228+ [
1229+ {
1230+ "message": "commit message",
1231+ "explanation": "detailed explanation",
1232+ "hunks": [{"hunk": 1}, {"hunk": 2}]
1233+ }
1234+ ]
1235+ ` ) ;
1236+
1237+ return {
1238+ isValid : false ,
1239+ errorMessage : errorMessage ,
1240+ retryPrompt : retryPrompt ,
1241+ } ;
1242+ }
10751243
1076- // Find any missing or extra hunks
1244+ // Validate the structure and hunk assignments
1245+ try {
1246+ const inputHunkIndices = inputHunkMap . map ( h => h . index ) ;
1247+ const allOutputHunks = commits . flatMap ( c => c . hunks . map ( h => h . hunk ) ) ;
1248+ const outputHunkIndices = new Map ( allOutputHunks . map ( ( hunk , index ) => [ hunk , index ] ) ) ;
10771249 const missingHunks = inputHunkIndices . filter ( i => ! outputHunkIndices . has ( i ) ) ;
1078- const extraHunks = [ ...outputHunkIndices ] . filter ( i => ! inputHunkIndices . includes ( i ) ) ;
1079- if ( missingHunks . length > 0 || extraHunks . length > 0 ) {
1080- let hunksMessage = '' ;
1250+
1251+ if ( missingHunks . length > 0 || allOutputHunks . length > inputHunkIndices . length ) {
1252+ const errorParts : string [ ] = [ ] ;
1253+ const retryParts : string [ ] = [ ] ;
1254+
10811255 if ( missingHunks . length > 0 ) {
10821256 const pluralize = missingHunks . length > 1 ? 's' : '' ;
1083- hunksMessage += ` ${ missingHunks . length } missing hunk${ pluralize } .` ;
1257+ errorParts . push ( `${ missingHunks . length } missing hunk${ pluralize } ` ) ;
1258+ retryParts . push ( `You missed hunk${ pluralize } ${ missingHunks . join ( ', ' ) } in your response` ) ;
10841259 }
1260+ const extraHunks = [ ...outputHunkIndices . keys ( ) ] . filter ( i => ! inputHunkIndices . includes ( i ) ) ;
10851261 if ( extraHunks . length > 0 ) {
10861262 const pluralize = extraHunks . length > 1 ? 's' : '' ;
1087- hunksMessage += ` ${ extraHunks . length } extra hunk${ pluralize } .` ;
1263+ errorParts . push ( `${ extraHunks . length } extra hunk${ pluralize } ` ) ;
1264+ retryParts . push (
1265+ `You included hunk${ pluralize } ${ extraHunks . join ( ', ' ) } which ${
1266+ extraHunks . length > 1 ? 'were' : 'was'
1267+ } not in the original diff`,
1268+ ) ;
1269+ }
1270+ const duplicateHunks = allOutputHunks . filter ( ( hunk , index ) => outputHunkIndices . get ( hunk ) ! !== index ) ;
1271+ const uniqueDuplicates = [ ...new Set ( duplicateHunks ) ] ;
1272+ if ( uniqueDuplicates . length > 0 ) {
1273+ const pluralize = uniqueDuplicates . length > 1 ? 's' : '' ;
1274+ errorParts . push ( `${ uniqueDuplicates . length } duplicate hunk${ pluralize } ` ) ;
1275+ retryParts . push ( `You used hunk${ pluralize } ${ uniqueDuplicates . join ( ', ' ) } multiple times` ) ;
10881276 }
10891277
1090- throw new Error (
1091- `Invalid response in generating ${
1092- options ?. generateCommits ? 'commits' : 'rebase'
1093- } result.${ hunksMessage } Try again or select a different AI model.`,
1094- ) ;
1095- }
1096- } catch ( ex ) {
1097- debugger ;
1098- if ( ex ?. message ?. includes ( 'Invalid response in generating' ) ) {
1099- throw ex ;
1278+ const errorMessage = `Invalid response in generating ${
1279+ options ?. generateCommits ? 'commits' : 'rebase'
1280+ } result. ${ errorParts . join ( ', ' ) } .`;
1281+
1282+ const retryPrompt = dedent ( `
1283+ Your previous response had issues: ${ retryParts . join ( ', ' ) } .
1284+
1285+ Please provide a corrected JSON response that:
1286+ 1. Includes ALL hunks from 1 to ${ Math . max ( ...inputHunkIndices ) } exactly once
1287+ 2. Does not include any hunk numbers outside this range
1288+ 3. Does not use any hunk more than once
1289+
1290+ Here was your previous response:
1291+ ${ rq . content }
1292+
1293+ Please provide the corrected JSON array of commits:
1294+ ` ) ;
1295+
1296+ return {
1297+ isValid : false ,
1298+ errorMessage : errorMessage ,
1299+ retryPrompt : retryPrompt ,
1300+ } ;
11001301 }
1101- throw new Error ( `Unable to parse ${ options ?. generateCommits ? 'commits' : 'rebase' } result` ) ;
1102- }
11031302
1104- return {
1105- ...rq ,
1106- ...result ,
1107- } ;
1303+ // If validation passes, return the commits
1304+ return { isValid : true , commits : commits } ;
1305+ } catch {
1306+ // Handle any errors during hunk validation (e.g., malformed commit structure)
1307+ const errorMessage = `Invalid commit structure in ${
1308+ options ?. generateCommits ? 'commits' : 'rebase'
1309+ } result`;
1310+ const retryPrompt = dedent ( `
1311+ Your previous response has an invalid commit structure. Each commit must have "message", "explanation", and "hunks" properties, where "hunks" is an array of objects with "hunk" numbers.
1312+
1313+ Here was your previous response:
1314+ ${ rq . content }
1315+
1316+ Please provide a valid JSON array of commits following this structure:
1317+ [
1318+ {
1319+ "message": "commit message",
1320+ "explanation": "detailed explanation",
1321+ "hunks": [{"hunk": 1}, {"hunk": 2}]
1322+ }
1323+ ]
1324+ ` ) ;
1325+
1326+ return {
1327+ isValid : false ,
1328+ errorMessage : errorMessage ,
1329+ retryPrompt : retryPrompt ,
1330+ } ;
1331+ }
11081332 }
11091333
11101334 private async sendRequest < T extends AIActionType > (
0 commit comments