Skip to content

Commit b8e2e61

Browse files
committed
feat(core): Add back prompt size checking.
1 parent be0fab1 commit b8e2e61

File tree

3 files changed

+52
-26
lines changed

3 files changed

+52
-26
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Changelog
22

33
## [Unreleased]
4+
### Changed
5+
- Use jtokkit library for getting max content length for a model and check if prompt is too large.
46

57
## [0.8.0] - 2023-04-14
68

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,6 @@ dependencies {
114114
exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib-jdk7")
115115
exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib-jdk8")
116116
}
117+
118+
implementation("com.knuddels:jtokkit:0.3.0")
117119
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/AICommitAction.kt

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import com.intellij.openapi.project.Project
1414
import com.intellij.openapi.vcs.VcsDataKeys
1515
import com.intellij.openapi.vcs.changes.Change
1616
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
17+
import com.knuddels.jtokkit.Encodings
18+
import com.knuddels.jtokkit.api.ModelType
1719
import git4idea.repo.GitRepositoryManager
1820
import kotlinx.coroutines.Dispatchers
1921
import kotlinx.coroutines.runBlocking
@@ -40,20 +42,24 @@ class AICommitAction : AnAction(), DumbAware {
4042
return@runBackgroundableTask
4143
}
4244

45+
val prompt = AppSettings.instance.getPrompt(diff)
46+
if (isPromptTooLarge(prompt)) {
47+
sendNotification(Notification.promptTooLarge())
48+
return@runBackgroundableTask
49+
}
50+
4351
if (commitMessage == null) {
4452
sendNotification(Notification.noCommitMessage())
4553
return@runBackgroundableTask
4654
}
4755

4856
val openAIService = OpenAIService.instance
49-
val prompt = AppSettings.instance.getPrompt(diff)
5057
runBlocking(Dispatchers.Main) {
5158
try {
5259
val generatedCommitMessage = openAIService.generateCommitMessage(prompt, 1)
5360
commitMessage.setCommitMessage(generatedCommitMessage)
5461
AppSettings.instance.recordHit()
55-
}
56-
catch (e: Exception) {
62+
} catch (e: Exception) {
5763
commitMessage.setCommitMessage(message("action.error"))
5864
sendNotification(Notification.unsuccessfulRequest(e.message ?: message("action.unknown-error")))
5965
}
@@ -62,40 +68,56 @@ class AICommitAction : AnAction(), DumbAware {
6268
}
6369

6470
private fun computeDiff(
65-
includedChanges: List<Change>,
66-
project: Project
71+
includedChanges: List<Change>,
72+
project: Project
6773
): String {
6874

6975
val gitRepositoryManager = GitRepositoryManager.getInstance(project)
7076

7177
// go through included changes, create a map of repository to changes and discard nulls
7278
val changesByRepository = includedChanges
73-
.mapNotNull { change ->
74-
change.virtualFile?.let { file ->
75-
gitRepositoryManager.getRepositoryForFileQuick(
76-
file
77-
) to change
79+
.mapNotNull { change ->
80+
change.virtualFile?.let { file ->
81+
gitRepositoryManager.getRepositoryForFileQuick(
82+
file
83+
) to change
84+
}
7885
}
79-
}
80-
.groupBy({ it.first }, { it.second })
86+
.groupBy({ it.first }, { it.second })
8187

8288

8389
// compute diff for each repository
8490
return changesByRepository
85-
.map { (repository, changes) ->
86-
repository?.let {
87-
val filePatches = IdeaTextPatchBuilder.buildPatch(
88-
project,
89-
changes,
90-
repository.root.toNioPath(), false, true
91-
)
92-
93-
val stringWriter = StringWriter()
94-
stringWriter.write("Repository: ${repository.root.path}\n")
95-
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
96-
stringWriter.toString()
91+
.map { (repository, changes) ->
92+
repository?.let {
93+
val filePatches = IdeaTextPatchBuilder.buildPatch(
94+
project,
95+
changes,
96+
repository.root.toNioPath(), false, true
97+
)
98+
99+
val stringWriter = StringWriter()
100+
stringWriter.write("Repository: ${repository.root.path}\n")
101+
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
102+
stringWriter.toString()
103+
}
97104
}
98-
}
99-
.joinToString("\n")
105+
.joinToString("\n")
106+
}
107+
108+
private fun isPromptTooLarge(prompt: String): Boolean {
109+
val registry = Encodings.newDefaultEncodingRegistry()
110+
111+
/*
112+
* Try to find the model type based on the model id by finding the longest matching model type
113+
* If no model type matches, let the request go through and let the OpenAI API handle it
114+
*/
115+
val modelType = ModelType.values()
116+
.filter { AppSettings.instance.openAIModelId.contains(it.name) }
117+
.maxByOrNull { it.name.length }
118+
?: return false
119+
120+
val encoding = registry.getEncoding(modelType.encodingType)
121+
return encoding.countTokens(prompt) > modelType.maxContextLength
100122
}
101123
}

0 commit comments

Comments
 (0)