Skip to content

Commit 7fc2b48

Browse files
authored
codewhisperer: add new classifier for more languages (#3863)
* codewhisperer: add new classifier for more languages
* update new classifier coefficients
* combine Windows 10 & 11
* do not match a single char when matching keyword
* remove extra whitespace
* use the correct left context on the current line
* fix detekt
1 parent d34711e commit 7fc2b48

File tree

6 files changed

+548
-29
lines changed

6 files changed

+548
-29
lines changed

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/CodeWhispererProgrammingLanguage.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ abstract class CodeWhispererProgrammingLanguage {
2828

2929
open fun isImportAdderSupported(): Boolean = false
3030

31-
open fun isClassifierSupported(): Boolean = false
31+
open fun isClassifierSupported(): Boolean = true
3232

3333
open fun isAllClassifier(): Boolean = false
3434

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererPlainText.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class CodeWhispererPlainText private constructor() : CodeWhispererProgrammingLan
1111

1212
override fun toTelemetryType(): CodewhispererLanguage = CodewhispererLanguage.Plaintext
1313

14+
override fun isClassifierSupported(): Boolean = false
15+
1416
companion object {
1517
const val ID = "plaintext"
1618

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/language/languages/CodeWhispererUnknownLanguage.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class CodeWhispererUnknownLanguage private constructor() : CodeWhispererProgramm
1111

1212
override fun toTelemetryType(): CodewhispererLanguage = CodewhispererLanguage.Unknown
1313

14+
override fun isClassifierSupported(): Boolean = false
15+
1416
companion object {
1517
const val ID = "unknown"
1618

jetbrains-core/src/software/aws/toolkits/jetbrains/services/codewhisperer/service/CodeWhispererAutoTriggerService.kt

Lines changed: 136 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,13 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
157157
}
158158

159159
val leftContextLines = caretContext.leftFileContext.split(Regex("\r?\n"))
160-
val leftContextAtCurrentLine = caretContext.leftContextOnCurrentLine
161-
val keyword = leftContextAtCurrentLine.trim().split(" ").let { tokens ->
160+
val leftContextLength = caretContext.leftFileContext.length
161+
val leftContextAtCurrentLine = if (leftContextLines.size - 1 >= 0) leftContextLines[leftContextLines.size - 1] else ""
162+
var keyword = ""
163+
val lastToken = leftContextAtCurrentLine.trim().split(" ").let { tokens ->
162164
if (tokens.size - 1 >= 0) tokens[tokens.size - 1] else ""
163165
}
166+
if (lastToken.length > 1) keyword = lastToken
164167

165168
val lengthOfLeftCurrent = leftContextAtCurrentLine.length
166169
val lengthOfLeftPrev = if (leftContextLines.size - 2 >= 0) {
@@ -171,36 +174,55 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
171174

172175
val rightContext = caretContext.rightFileContext
173176
val lengthOfRight = rightContext.trim().length
174-
val triggerTypeCoefficient = CodeWhispererClassifierConstants.triggerTypeCoefficientMap[automatedTriggerType] ?: 0.0
177+
178+
val isExperimentGroup = CodeWhispererUserGroupSettings.getInstance().getUserGroup() == CodeWhispererUserGroup.Classifier
179+
180+
val triggerTypeCoefficient = if (isExperimentGroup) {
181+
CodeWhispererClassifierConstants.triggerTypeCoefficientMapExp[automatedTriggerType] ?: 0.0
182+
} else CodeWhispererClassifierConstants.triggerTypeCoefficientMap[automatedTriggerType] ?: 0.0
175183

176184
val osCoefficient: Double = if (SystemInfo.isMac) {
177-
CodeWhispererClassifierConstants.osMap["Mac OS X"] ?: 0.0
185+
if (isExperimentGroup) {
186+
CodeWhispererClassifierConstants.osMapExp["Mac OS X"] ?: 0.0
187+
} else CodeWhispererClassifierConstants.osMap["Mac OS X"] ?: 0.0
178188
} else if (SystemInfo.isWindows) {
179189
val osVersion = SystemInfo.OS_VERSION
180-
if (osVersion.contains("11", true)) {
181-
CodeWhispererClassifierConstants.osMap["Windows 10"]
182-
} else if (osVersion.contains("10", true)) {
183-
CodeWhispererClassifierConstants.osMap["Windows 10"]
190+
if (osVersion.contains("11", true) || osVersion.contains("10", true)) {
191+
if (isExperimentGroup) {
192+
CodeWhispererClassifierConstants.osMapExp["Windows 10"]
193+
} else {
194+
CodeWhispererClassifierConstants.osMap["Windows 10"]
195+
}
184196
} else if (osVersion.contains("7", true)) {
185-
CodeWhispererClassifierConstants.osMap["Windows 7"]
197+
if (isExperimentGroup) {
198+
CodeWhispererClassifierConstants.osMapExp["Windows"]
199+
} else {
200+
CodeWhispererClassifierConstants.osMap["Windows 7"]
201+
}
186202
} else {
187-
0.0
203+
if (isExperimentGroup) CodeWhispererClassifierConstants.osMapExp["Windows"] else 0.0
188204
}
189205
} else {
190206
0.0
191207
} ?: 0.0
192208

193209
val lastCharCoefficient = if (leftContextAtCurrentLine.length - 1 >= 0) {
194-
CodeWhispererClassifierConstants.coefficientsMap[leftContextAtCurrentLine[leftContextAtCurrentLine.length - 1].toString()] ?: 0.0
210+
if (isExperimentGroup) {
211+
CodeWhispererClassifierConstants.coefficientsMapExp[leftContextAtCurrentLine[leftContextAtCurrentLine.length - 1].toString()] ?: 0.0
212+
} else CodeWhispererClassifierConstants.coefficientsMap[leftContextAtCurrentLine[leftContextAtCurrentLine.length - 1].toString()] ?: 0.0
195213
} else {
196214
0.0
197215
}
198216

199-
val keywordCoefficient = CodeWhispererClassifierConstants.coefficientsMap[keyword] ?: 0.0
200-
val languageCoefficient = CodeWhispererClassifierConstants.languageMap[language] ?: 0.0
201-
val ideCoefficient = 0
217+
val keywordCoefficient = if (isExperimentGroup) {
218+
CodeWhispererClassifierConstants.coefficientsMapExp[keyword] ?: 0.0
219+
} else CodeWhispererClassifierConstants.coefficientsMap[keyword] ?: 0.0
220+
val languageCoefficient = if (isExperimentGroup) {
221+
CodeWhispererClassifierConstants.languageMapExp[language] ?: 0.0
222+
} else CodeWhispererClassifierConstants.languageMap[language] ?: 0.0
223+
val ideCoefficient = 0.0
202224

203-
val lineDiff = lastInvocationLineNum?.let { (caretPosition.line.toDouble() - it) } ?: 0.0
225+
val lineDiff = if (isExperimentGroup) 0.0 else lastInvocationLineNum?.let { (caretPosition.line.toDouble() - it) } ?: 0.0
204226

205227
var previousOneAccept: Double = 0.0
206228
var previousOneReject: Double = 0.0
@@ -212,28 +234,81 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
212234
previousOneOther = 0.0
213235
} else {
214236
previousOneAccept =
215-
if (previousOneDecision == CodewhispererPreviousSuggestionState.Accept) CodeWhispererClassifierConstants.prevDecisionAcceptCoefficient else 0.0
237+
if (previousOneDecision == CodewhispererPreviousSuggestionState.Accept) {
238+
if (isExperimentGroup) {
239+
CodeWhispererClassifierConstants.prevDecisionAcceptCoefficientExp
240+
} else CodeWhispererClassifierConstants.prevDecisionAcceptCoefficient
241+
} else {
242+
0.0
243+
}
216244
previousOneReject =
217-
if (previousOneDecision == CodewhispererPreviousSuggestionState.Reject) CodeWhispererClassifierConstants.prevDecisionRejectCoefficient else 0.0
245+
if (previousOneDecision == CodewhispererPreviousSuggestionState.Reject) {
246+
if (isExperimentGroup) {
247+
CodeWhispererClassifierConstants.prevDecisionRejectCoefficientExp
248+
} else CodeWhispererClassifierConstants.prevDecisionRejectCoefficient
249+
} else {
250+
0.0
251+
}
218252
previousOneOther =
219253
if (
220254
previousOneDecision != CodewhispererPreviousSuggestionState.Accept &&
221255
previousOneDecision != CodewhispererPreviousSuggestionState.Reject
222256
) {
223-
CodeWhispererClassifierConstants.prevDecisionOtherCoefficient
257+
if (isExperimentGroup) {
258+
CodeWhispererClassifierConstants.prevDecisionOtherCoefficientExp
259+
} else CodeWhispererClassifierConstants.prevDecisionOtherCoefficient
224260
} else {
225261
0.0
226262
}
227263
}
228264

265+
var leftContextLengthCoefficient: Double = 0.0
266+
if (isExperimentGroup) {
267+
leftContextLengthCoefficient = when (leftContextLength) {
268+
in 0..4 -> CodeWhispererClassifierConstants.lengthLeft0To5Exp
269+
in 5..9 -> CodeWhispererClassifierConstants.lengthLeft5To10Exp
270+
in 10..19 -> CodeWhispererClassifierConstants.lengthLeft10To20Exp
271+
in 20..29 -> CodeWhispererClassifierConstants.lengthLeft20To30Exp
272+
in 30..39 -> CodeWhispererClassifierConstants.lengthLeft30To40Exp
273+
in 40..49 -> CodeWhispererClassifierConstants.lengthLeft40To50Exp
274+
else -> 0.0
275+
}
276+
}
277+
278+
val normalizedLengthOfRight = if (isExperimentGroup) {
279+
CodeWhispererClassifierConstants.lengthofRightCoefficientExp * VariableTypeNeedNormalize.LenRight.normalizeExp(lengthOfRight.toDouble())
280+
} else CodeWhispererClassifierConstants.lengthofRightCoefficient * VariableTypeNeedNormalize.LenRight.normalize(lengthOfRight.toDouble())
281+
282+
val normalizedLengthOfLeftCurrent = if (isExperimentGroup) {
283+
CodeWhispererClassifierConstants.lengthOfLeftCurrentCoefficientExp *
284+
VariableTypeNeedNormalize.LenLeftCur.normalizeExp(lengthOfLeftCurrent.toDouble())
285+
} else CodeWhispererClassifierConstants.lengthOfLeftCurrentCoefficient * VariableTypeNeedNormalize.LenLeftCur.normalize(lengthOfLeftCurrent.toDouble())
286+
287+
val normalizedLengthOfPrev = if (isExperimentGroup) {
288+
CodeWhispererClassifierConstants.lengthOfLeftPrevCoefficientExp * VariableTypeNeedNormalize.LenLeftPrev.normalizeExp(lengthOfLeftPrev)
289+
} else CodeWhispererClassifierConstants.lengthOfLeftPrevCoefficient * VariableTypeNeedNormalize.LenLeftPrev.normalize(lengthOfLeftPrev)
290+
291+
val normalizedLineNum = if (isExperimentGroup) {
292+
CodeWhispererClassifierConstants.lineNumCoefficientExp * VariableTypeNeedNormalize.LineNum.normalizeExp(caretPosition.line.toDouble())
293+
} else CodeWhispererClassifierConstants.lineNumCoefficient * VariableTypeNeedNormalize.LineNum.normalize(caretPosition.line.toDouble())
294+
295+
val normalizedCursor = if (isExperimentGroup) {
296+
0.0
297+
} else CodeWhispererClassifierConstants.cursorOffsetCoefficient * VariableTypeNeedNormalize.Cursor.normalize(caretPosition.offset.toDouble())
298+
299+
val normalizedLineDiff = if (isExperimentGroup) {
300+
0.0
301+
} else CodeWhispererClassifierConstants.lineDiffCoefficient * VariableTypeNeedNormalize.LineDiff.normalize(lineDiff)
302+
303+
val intercept = if (isExperimentGroup) CodeWhispererClassifierConstants.interceptExp else CodeWhispererClassifierConstants.intercept
304+
229305
val resultBeforeSigmoid =
230-
CodeWhispererClassifierConstants.lengthofRightCoefficient * VariableTypeNeedNormalize.LenRight.normalize(lengthOfRight.toDouble()) +
231-
CodeWhispererClassifierConstants.lengthOfLeftCurrentCoefficient *
232-
VariableTypeNeedNormalize.LenLeftCur.normalize(lengthOfLeftCurrent.toDouble()) +
233-
CodeWhispererClassifierConstants.lengthOfLeftPrevCoefficient * VariableTypeNeedNormalize.LenLeftPrev.normalize(lengthOfLeftPrev) +
234-
CodeWhispererClassifierConstants.lineNumCoefficient * VariableTypeNeedNormalize.LineNum.normalize(caretPosition.line.toDouble()) +
235-
CodeWhispererClassifierConstants.cursorOffsetCoefficient * VariableTypeNeedNormalize.Cursor.normalize(caretPosition.offset.toDouble()) +
236-
CodeWhispererClassifierConstants.lineDiffCoefficient * VariableTypeNeedNormalize.LineDiff.normalize(lineDiff) +
306+
normalizedLengthOfRight +
307+
normalizedLengthOfLeftCurrent +
308+
normalizedLengthOfPrev +
309+
normalizedLineNum +
310+
normalizedCursor +
311+
normalizedLineDiff +
237312
languageCoefficient +
238313
osCoefficient +
239314
triggerTypeCoefficient +
@@ -243,7 +318,8 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
243318
previousOneAccept +
244319
previousOneReject +
245320
previousOneOther +
246-
CodeWhispererClassifierConstants.intercept
321+
leftContextLengthCoefficient +
322+
intercept
247323

248324
val shouldTrigger = sigmoid(resultBeforeSigmoid) > getThreshold()
249325
return ClassifierResult(shouldTrigger, sigmoid(resultBeforeSigmoid))
@@ -253,10 +329,15 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
253329

254330
companion object {
255331
private const val triggerThreshold: Double = 0.4
332+
private const val triggerThresholdExp: Double = 0.43
256333

257334
fun getInstance(): CodeWhispererAutoTriggerService = service()
258335

259-
fun getThreshold(): Double = triggerThreshold
336+
fun getThreshold(): Double = if (CodeWhispererUserGroupSettings.getInstance().getUserGroup() == CodeWhispererUserGroup.Classifier) {
337+
triggerThresholdExp
338+
} else {
339+
triggerThreshold
340+
}
260341

261342
fun sigmoid(x: Double): Double = 1 / (1 + exp(-x))
262343
}
@@ -265,24 +346,31 @@ class CodeWhispererAutoTriggerService : CodeWhispererAutoTriggerHandler, Disposa
265346
private enum class VariableTypeNeedNormalize {
266347
Cursor {
267348
override fun normalize(value: Double): Double = (value - minn.cursor) / (maxx.cursor - minn.cursor)
349+
override fun normalizeExp(value: Double): Double = 0.0
268350
},
269351
LineNum {
270352
override fun normalize(value: Double): Double = (value - minn.lineNum) / (maxx.lineNum - minn.lineNum)
353+
override fun normalizeExp(value: Double): Double = (value - minnExp.lineNum) / (maxxExp.lineNum - minnExp.lineNum)
271354
},
272355
LenLeftCur {
273356
override fun normalize(value: Double): Double = (value - minn.lenLeftCur) / (maxx.lenLeftCur - minn.lenLeftCur)
357+
override fun normalizeExp(value: Double): Double = (value - minnExp.lenLeftCur) / (maxxExp.lenLeftCur - minnExp.lenLeftCur)
274358
},
275359
LenLeftPrev {
276360
override fun normalize(value: Double): Double = (value - minn.lenLeftPrev) / (maxx.lenLeftPrev - minn.lenLeftPrev)
361+
override fun normalizeExp(value: Double): Double = (value - minnExp.lenLeftPrev) / (maxxExp.lenLeftPrev - minnExp.lenLeftPrev)
277362
},
278363
LenRight {
279364
override fun normalize(value: Double): Double = (value - minn.lenRight) / (maxx.lenRight - minn.lenRight)
365+
override fun normalizeExp(value: Double): Double = (value - minnExp.lenRight) / (maxxExp.lenRight - minnExp.lenRight)
280366
},
281367
LineDiff {
282368
override fun normalize(value: Double): Double = (value - minn.lineDiff) / (maxx.lineDiff - minn.lineDiff)
369+
override fun normalizeExp(value: Double): Double = 0.0
283370
};
284371

285372
abstract fun normalize(value: Double): Double
373+
abstract fun normalizeExp(toDouble: Double): Double
286374

287375
data class NormalizedCoefficients(
288376
val cursor: Double,
@@ -293,6 +381,13 @@ private enum class VariableTypeNeedNormalize {
293381
val lineDiff: Double,
294382
)
295383

384+
data class NormalizedCoefficientsExp(
385+
val lineNum: Double,
386+
val lenLeftCur: Double,
387+
val lenLeftPrev: Double,
388+
val lenRight: Double,
389+
)
390+
296391
companion object {
297392
private val maxx = NormalizedCoefficients(
298393
cursor = 84716.0,
@@ -303,6 +398,13 @@ private enum class VariableTypeNeedNormalize {
303398
lineDiff = 270.0,
304399
)
305400

401+
private val maxxExp = NormalizedCoefficientsExp(
402+
lineNum = 4631.0,
403+
lenLeftCur = 157.0,
404+
lenLeftPrev = 176.0,
405+
lenRight = 10239.0,
406+
)
407+
306408
private val minn = NormalizedCoefficients(
307409
cursor = 1.0,
308410
lineNum = 0.0,
@@ -311,5 +413,12 @@ private enum class VariableTypeNeedNormalize {
311413
lenRight = 0.0,
312414
lineDiff = -28336.0,
313415
)
416+
417+
private val minnExp = NormalizedCoefficientsExp(
418+
lineNum = 0.0,
419+
lenLeftCur = 0.0,
420+
lenLeftPrev = 0.0,
421+
lenRight = 0.0,
422+
)
314423
}
315424
}

0 commit comments

Comments
 (0)