diff --git a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/handler/editCompletionHandler.ts b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/handler/editCompletionHandler.ts index bbfd441699..42a45f1f8c 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/handler/editCompletionHandler.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/handler/editCompletionHandler.ts @@ -38,6 +38,7 @@ import { AmazonQError, AmazonQServiceConnectionExpiredError } from '../../../sha import { DocumentChangedListener } from '../documentChangedListener' import { EMPTY_RESULT, EDIT_DEBOUNCE_INTERVAL_MS } from '../contants/constants' import { StreakTracker } from '../tracker/streakTracker' +import { processEditSuggestion } from '../utils/diffUtils' export class EditCompletionHandler { private readonly editsEnabled: boolean @@ -396,6 +397,13 @@ export class EditCompletionHandler { textDocument?.uri || '' ) + const processedSuggestion = processEditSuggestion( + suggestion.content, + session.startPosition, + session.document + ) + const isInlineEdit = processedSuggestion.type === SuggestionType.EDIT + if (isSimilarToRejected) { // Mark as rejected in the session session.setSuggestionState(suggestion.itemId, 'Reject') @@ -405,14 +413,14 @@ export class EditCompletionHandler { // Return empty item that will be filtered out return { insertText: '', - isInlineEdit: true, + isInlineEdit: isInlineEdit, itemId: suggestion.itemId, } } return { - insertText: suggestion.content, - isInlineEdit: true, + insertText: processedSuggestion.suggestionContent, + isInlineEdit: isInlineEdit, itemId: suggestion.itemId, } }) diff --git a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.test.ts b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.test.ts index 40caa458a5..20d5890d77 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.test.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.test.ts @@ -1,5 +1,315 @@ import * as assert from 'assert' -import { getAddedAndDeletedLines, getCharacterDifferences, generateDiffContexts } from './diffUtils' +import { + categorizeUnifieddiff, + extractAdditions, + getAddedAndDeletedLines, + getCharacterDifferences, + generateDiffContexts, +} from './diffUtils' + +describe('extractAdditions', function () { + interface Case { + udiff: string + expected: string + } + + const cases: Case[] = [ + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java +@@ -1,9 +1,10 @@ + public class MathUtil { + // write a function to add 2 numbers + public static int add(int a, int b) { + ++ return a + b; + } + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b;`, + expected: ' return a + b;', + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java +@@ -1,9 +1,17 @@ + public class MathUtil { + // write a function to add 2 numbers + public static int add(int a, int b) { + ++ if (a > Integer.MAX_VALUE - b){ ++ throw new IllegalArgumentException("Overflow!"); ++ } ++ else if (a < Integer.MIN_VALUE - b){ ++ throw new IllegalArgumentException("Underflow"); ++ } ++ else{ ++ return a + b; ++ } + } + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b;`, + expected: ` if (a > Integer.MAX_VALUE - b){ + throw new IllegalArgumentException("Overflow!"); + } + else if (a < Integer.MIN_VALUE - b){ + throw new IllegalArgumentException("Underflow"); + } + else{ + return a + b; + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -6,7 +6,11 @@ + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b; + } +- ++ ++ // write a function to multiply 2 numbers ++ public static int multiply(int a, int b) { ++ return a * b; ++ } + }`, + expected: ` + // write a function to multiply 2 numbers + public static int multiply(int a, int b) { + return a * b; + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -3,7 +3,9 @@ + public static int add(int a, int b) { + return a + b; + } + + // write a function to subtract 2 numbers +- ++ public static int subtract(int a, int b) { ++ return a - b; ++ } + }`, + expected: ` public static int subtract(int a, int b) { + return a - b; + }`, + }, + ] + + for (let i = 0; i < cases.length; i++) { + it(`case ${i}`, function () { + const c = cases[i] + const udiff = c.udiff + const expected = c.expected + + const actual = extractAdditions(udiff) + assert.strictEqual(actual, expected) + }) + } +}) + +describe('categorizeUnifieddiffV2v2 should return correct type (addOnly, edit, deleteOnly)', function () { + interface Case { + udiff: string + } + + describe('addOnly', function () { + const addOnlyCases: Case[] = [ + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -6,7 +6,11 @@ + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b; + } +- ++ ++ // write a function to multiply 2 numbers ++ public static int multiply(int a, int b) { ++ return a * b; ++ } + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -6,7 +6,11 @@ + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b; + } +- ++ ++ ++ // write a function to multiply 2 numbers ++ public static int multiply(int a, int b) { ++ return a * b; ++ } + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -6,7 +6,11 @@ + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b; + } +- ++ ++ // write a function to multiply 2 numbers ++ public static int multiply(int a, int b) { ++ return a * b; ++ } + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/MathUtil.java +@@ -1,9 +1,10 @@ + public class MathUtil { + // write a function to add 2 numbers + public static int add(int a, int b) { + ++ return a + b; + } + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + return a - b;`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -3,7 +3,9 @@ + public static int add(int a, int b) { + return a + b; + } + + // write a function to subtract 2 numbers +- ++ public static int subtract(int a, int b) { ++ return a - b; ++ } + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator-2/src/main/hello/MathUtil.java +@@ -4,8 +4,8 @@ + return a + b; + } + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { +- return ++ return a - b; + } + }`, + }, + { + udiff: `--- file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/LRUCache.java ++++ file:///Volumes/workplace/ide/sample_projects/Calculator/src/main/hello/LRUCache.java +@@ -7,7 +7,11 @@ + private Map map; + private DoubleLinkedList list; + private int capacity; + + // get +- public LruCache ++ public LruCache(int capacity) { ++ this.capacity = capacity; ++ map = new HashMap<>(); ++ list = new DoubleLinkedList(); ++ } + }`, + }, + ] + + for (let i = 0; i < addOnlyCases.length; i++) { + it(`case ${i}`, function () { + const actual = categorizeUnifieddiff(addOnlyCases[i].udiff) + assert.strictEqual(actual, 'addOnly') + }) + } + }) + + describe('edit', function () { + const cases: Case[] = [ + { + udiff: `--- a/src/main/hello/MathUtil.java ++++ b/src/main/hello/MathUtil.java +@@ -1,11 +1,11 @@ + public class MathUtil { + // write a function to add 2 numbers +- public static int add(int a, int b) { ++ public static double add(double a, double b) { + return a + b; + } + + // write a function to subtract 2 numbers + public static int subtract(int a, int b) { + public static double subtract(double a, double b) { + return a - b; + } + }`, + }, + { + udiff: `--- a/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts ++++ b/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts +@@ -502,11 +502,7 @@ export class CodeWhispererServiceToken extends CodeWhispererServiceBase { + : undefined + } + +- private withProfileArn(request: T): T { +- if (!this.profileArn) return request +- +- return { ...request, profileArn: this.profileArn } +- } ++ // ddddddddddddddddd + + async generateSuggestions(request: BaseGenerateSuggestionsRequest): Promise { + // Cast is now safe because GenerateTokenSuggestionsRequest extends GenerateCompletionsRequest`, + }, + { + udiff: `--- file:///Users/atona/workplace/NEP/language-servers/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/textDocumentUtils.ts ++++ file:///Users/atona/workplace/NEP/language-servers/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/textDocumentUtils.ts +@@ -15,11 +15,11 @@ + return '' + } + } + + export const getTextDocument = async (uri: string, workspace: any, logging: any): Promise => { +- let ++ if (!textDocument) { + if (!textDocument) { + try { + const content = await workspace.fs.readFile(URI.parse(uri).fsPath) + const languageId = getLanguageIdFromUri(uri) + textDocument = TextDocument.create(uri, languageId, 0, content)`, + }, + ] + + for (let i = 0; i < cases.length; i++) { + it(`case ${i}`, function () { + const actual = categorizeUnifieddiff(cases[i].udiff) + assert.strictEqual(actual, 'edit') + }) + } + }) +}) describe('diffUtils', () => { describe('getAddedAndDeletedLines', () => { diff --git a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.ts b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.ts index 77e6c28e6c..8c67a92746 100644 --- a/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.ts +++ b/server/aws-lsp-codewhisperer/src/language-server/inline-completion/utils/diffUtils.ts @@ -6,6 +6,9 @@ import * as diff from 'diff' import { CodeWhispererSupplementalContext, CodeWhispererSupplementalContextItem } from '../../../shared/models/model' import { trimSupplementalContexts } from '../../../shared/supplementalContextUtil/supplementalContextUtil' +import { Position, TextDocument, Range } from '@aws/language-server-runtimes/protocol' +import { SuggestionType } from '../../../shared/codeWhispererService' +import { getPrefixSuffixOverlap } from './mergeRightUtils' /** * Generates a unified diff format between old and new file contents @@ -197,3 +200,203 @@ export function getCharacterDifferences(addedLines: string[], deletedLines: stri charactersRemoved: deletedText.length - lcsLen, } } + +export function processEditSuggestion( + unifiedDiff: string, + triggerPosition: Position, + document: TextDocument +): { suggestionContent: string; type: SuggestionType } { + // Assume it's an edit if anything goes wrong, at the very least it will not be rendered incorrectly + let diffCategory: ReturnType = 'edit' + try { + diffCategory = categorizeUnifieddiff(unifiedDiff) + } catch (e) { + // We dont have logger here.... + diffCategory = 'edit' + } + + if (diffCategory === 'addOnly') { + const preprocessAdd = extractAdditions(unifiedDiff) + const leftContextAtTriggerLine = document.getText( + Range.create(Position.create(triggerPosition.line, 0), triggerPosition) + ) + /** + * SHOULD NOT remove the entire overlapping string, the way inline suggestion prefix matching work depends on where it triggers + * For example (^ note where user triggers) + * console.lo + * ^ + * if LSP returns `g('foo')` instead of `.log()` the suggestion will be discarded because prefix doesnt match + */ + const processedAdd = removeOverlapCodeFromSuggestion(leftContextAtTriggerLine, preprocessAdd) + return { + suggestionContent: processedAdd, + type: SuggestionType.COMPLETION, + } + } else { + return { + suggestionContent: unifiedDiff, + type: SuggestionType.EDIT, + } + } +} + +// TODO: MAKE it a class and abstract all the business parsing logic within the classsssss so we dont need to redo the same thing again and again +interface UnifiedDiff { + linesWithoutHeaders: string[] + firstMinusIndex: number + firstPlusIndex: number + minusIndexes: number[] + plusIndexes: number[] +} + +// TODO: refine +export function readUdiff(unifiedDiff: string): UnifiedDiff { + const lines = unifiedDiff.split('\n') + const headerEndIndex = lines.findIndex(l => l.startsWith('@@')) + if (headerEndIndex === -1) { + throw new Error('not able to parse') + } + const relevantLines = lines.slice(headerEndIndex + 1) + if (relevantLines.length === 0) { + throw new Error('not able to parse') + } + + const minusIndexes: number[] = [] + const plusIndexes: number[] = [] + for (let i = 0; i < relevantLines.length; i++) { + const l = relevantLines[i] + if (l.startsWith('-')) { + minusIndexes.push(i) + } else if (l.startsWith('+')) { + plusIndexes.push(i) + } + } + + const firstMinusIndex = relevantLines.findIndex(s => s.startsWith('-')) + const firstPlusIndex = relevantLines.findIndex(s => s.startsWith('+')) + + return { + linesWithoutHeaders: relevantLines, + firstMinusIndex: firstMinusIndex, + firstPlusIndex: firstPlusIndex, + minusIndexes: minusIndexes, + plusIndexes: plusIndexes, + } +} + +export function categorizeUnifieddiff(unifiedDiff: string): 'addOnly' | 'deleteOnly' | 'edit' { + try { + const d = readUdiff(unifiedDiff) + const firstMinusIndex = d.firstMinusIndex + const firstPlusIndex = d.firstPlusIndex + const diffWithoutHeaders = d.linesWithoutHeaders + + if (firstMinusIndex === -1 && firstPlusIndex === -1) { + return 'edit' + } + + if (firstMinusIndex === -1 && firstPlusIndex !== -1) { + return 'addOnly' + } + + if (firstMinusIndex !== -1 && firstPlusIndex === -1) { + return 'deleteOnly' + } + + const minusIndexes = d.minusIndexes + const plusIndexes = d.plusIndexes + + // If there are multiple (> 1) non empty '-' lines, it must be edit + const c = minusIndexes.reduce((acc: number, cur: number): number => { + if (diffWithoutHeaders[cur].trim().length > 0) { + return acc++ + } + + return acc + }, 0) + + if (c > 1) { + return 'edit' + } + + // If last '-' line is followed by '+' block, it could be addonly + if (plusIndexes[0] === minusIndexes[minusIndexes.length - 1] + 1) { + const minusLine = diffWithoutHeaders[minusIndexes[minusIndexes.length - 1]].substring(1) + const pluscode = extractAdditions(unifiedDiff) + + // If minusLine subtract the longest common substring of minusLine and plugcode and it's empty string, it's addonly + const commonPrefix = longestCommonPrefix(minusLine, pluscode) + if (minusLine.substring(commonPrefix.length).trim().length === 0) { + return 'addOnly' + } + } + + return 'edit' + } catch (e) { + return 'edit' + } +} + +// TODO: current implementation here assumes service only return 1 chunk of edits (consecutive lines) and hacky +export function extractAdditions(unifiedDiff: string): string { + const lines = unifiedDiff.split('\n') + let completionSuggestion = '' + let isInAdditionBlock = false + + for (const line of lines) { + // Skip diff headers (files) + if (line.startsWith('+++') || line.startsWith('---')) { + continue + } + + // Skip hunk headers (@@ lines) + if (line.startsWith('@@')) { + continue + } + + // Handle additions + if (line.startsWith('+')) { + completionSuggestion += line.substring(1) + '\n' + isInAdditionBlock = true + } else if (isInAdditionBlock && !line.startsWith('+')) { + // End of addition block + isInAdditionBlock = false + } + } + + // Remove trailing newline + return completionSuggestion.trimEnd() +} + +/** + * + * example + * code = 'return' + * suggestion = 'return a + b;' + * output = ' a + b;' + */ +export function removeOverlapCodeFromSuggestion(code: string, suggestion: string): string { + const suggestionLines = suggestion.split('\n') + const firstLineSuggestion = suggestionLines[0] + + // Find the common string in code surfix and prefix of suggestion + const s = getPrefixSuffixOverlap(code, firstLineSuggestion) + + // Remove overlap s from suggestion + return suggestion.substring(s.length) +} + +export function longestCommonPrefix(str1: string, str2: string): string { + const minLength = Math.min(str1.length, str2.length) + let prefix = '' + + for (let i = 0; i < minLength; i++) { + if (str1[i] === str2[i]) { + prefix += str1[i] + } else { + break + } + } + + return prefix +} diff --git a/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts b/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts index 9c229ccd86..27dac024af 100644 --- a/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts +++ b/server/aws-lsp-codewhisperer/src/shared/codeWhispererService.ts @@ -522,20 +522,21 @@ export class CodeWhispererServiceToken extends CodeWhispererServiceBase { } const beforeApiCall = Date.now() - let recentEditsLogStr = '' - const recentEdits = tokenRequest.supplementalContexts?.filter(it => it.type === 'PreviousEditorState') - if (recentEdits) { - if (recentEdits.length === 0) { - recentEditsLogStr += `No recent edits` - } else { - recentEditsLogStr += '\n' - for (let i = 0; i < recentEdits.length; i++) { - const e = recentEdits[i] - recentEditsLogStr += `[recentEdits ${i}th]:\n` - recentEditsLogStr += `${e.content}\n` - } - } - } + // TODO: Should make context log as a dev option, too noisy, comment it out temporarily + // let recentEditsLogStr = '' + // const recentEdits = tokenRequest.supplementalContexts?.filter(it => it.type === 'PreviousEditorState') + // if (recentEdits) { + // if (recentEdits.length === 0) { + // recentEditsLogStr += `No recent edits` + // } else { + // recentEditsLogStr += '\n' + // for (let i = 0; i < recentEdits.length; i++) { + // const e = recentEdits[i] + // recentEditsLogStr += `[recentEdits ${i}th]:\n` + // recentEditsLogStr += `${e.content}\n` + // } + // } + // } logstr += `@@request metadata@@ "endpoint": ${this.codeWhispererEndpoint}, @@ -545,8 +546,8 @@ export class CodeWhispererServiceToken extends CodeWhispererServiceBase { rightContextLength: ${request.fileContext.rightFileContent.length}, "language": ${tokenRequest.fileContext.programmingLanguage.languageName}, "supplementalContextCount": ${tokenRequest.supplementalContexts?.length ?? 0}, - "request.nextToken": ${tokenRequest.nextToken}, - "recentEdits": ${recentEditsLogStr}\n` + "request.nextToken": ${tokenRequest.nextToken}` + // "recentEdits": ${recentEditsLogStr}\n` const response = await this.client.generateCompletions(this.withProfileArn(tokenRequest)).promise() @@ -566,7 +567,7 @@ export class CodeWhispererServiceToken extends CodeWhispererServiceBase { "sessionId": ${responseContext.codewhispererSessionId}, "response.completions.length": ${response.completions?.length ?? 0}, "response.predictions.length": ${response.predictions?.length ?? 0}, - "predictionType": ${tokenRequest.predictionTypes?.toString() ?? ''}, + "predictionType": ${tokenRequest.predictionTypes?.toString() ?? 'Not specified (COMPLETIONS)'}, "latency": ${Date.now() - beforeApiCall}, "response.nextToken": ${response.nextToken}, "firstSuggestion": ${firstSuggestionLogstr}` @@ -599,6 +600,7 @@ export class CodeWhispererServiceToken extends CodeWhispererServiceBase { } } + // Backward compatibility, completions will be returned if predictionType is not specified (either Completion or Edit) for (const recommendation of apiResponse?.completions ?? []) { Object.assign(recommendation, { itemId: this.generateItemId() }) }