Skip to content

Commit efd54b3

Browse files
committed
feat(prompts): add prompt preview
1 parent fc56be6 commit efd54b3

File tree

5 files changed

+159
-102
lines changed

5 files changed

+159
-102
lines changed
Lines changed: 7 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
11
package com.github.blarc.ai.commits.intellij.plugin
22

33
import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
4+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
5+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
6+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
7+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.isPromptTooLarge
48
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
59
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
610
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
711
import com.intellij.openapi.actionSystem.AnAction
812
import com.intellij.openapi.actionSystem.AnActionEvent
9-
import com.intellij.openapi.diff.impl.patch.IdeaTextPatchBuilder
10-
import com.intellij.openapi.diff.impl.patch.UnifiedDiffWriter
1113
import com.intellij.openapi.progress.runBackgroundableTask
1214
import com.intellij.openapi.project.DumbAware
13-
import com.intellij.openapi.project.Project
1415
import com.intellij.openapi.vcs.VcsDataKeys
15-
import com.intellij.openapi.vcs.changes.Change
1616
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
1717
import com.knuddels.jtokkit.Encodings
1818
import com.knuddels.jtokkit.api.ModelType
19-
import git4idea.repo.GitRepositoryManager
2019
import kotlinx.coroutines.Dispatchers
2120
import kotlinx.coroutines.runBlocking
22-
import java.io.StringWriter
2321

2422
class AICommitAction : AnAction(), DumbAware {
2523
override fun actionPerformed(e: AnActionEvent) {
@@ -41,14 +39,8 @@ class AICommitAction : AnAction(), DumbAware {
4139
return@runBackgroundableTask
4240
}
4341

44-
var branch = commonBranch(includedChanges, project)
45-
if (branch == null) {
46-
sendNotification(Notification.noCommonBranch())
47-
// hardcoded fallback branch
48-
branch = "main"
49-
}
50-
51-
val prompt = AppSettings.instance.getPrompt(diff, branch)
42+
val branch = commonBranch(includedChanges, project)
43+
val prompt = constructPrompt(AppSettings.instance.currentPrompt.content, diff, branch)
5244
if (isPromptTooLarge(prompt)) {
5345
sendNotification(Notification.promptTooLarge())
5446
return@runBackgroundableTask
@@ -72,70 +64,4 @@ class AICommitAction : AnAction(), DumbAware {
7264
}
7365
}
7466
}
75-
76-
private fun computeDiff(
77-
includedChanges: List<Change>,
78-
project: Project
79-
): String {
80-
81-
val gitRepositoryManager = GitRepositoryManager.getInstance(project)
82-
83-
// go through included changes, create a map of repository to changes and discard nulls
84-
val changesByRepository = includedChanges
85-
.filter {
86-
it.virtualFile?.path?.let { path ->
87-
AICommitsUtils.isPathExcluded(path, project)
88-
} ?: false
89-
}
90-
.mapNotNull { change ->
91-
change.virtualFile?.let { file ->
92-
gitRepositoryManager.getRepositoryForFileQuick(
93-
file
94-
) to change
95-
}
96-
}
97-
.groupBy({ it.first }, { it.second })
98-
99-
100-
// compute diff for each repository
101-
return changesByRepository
102-
.map { (repository, changes) ->
103-
repository?.let {
104-
val filePatches = IdeaTextPatchBuilder.buildPatch(
105-
project,
106-
changes,
107-
repository.root.toNioPath(), false, true
108-
)
109-
110-
val stringWriter = StringWriter()
111-
stringWriter.write("Repository: ${repository.root.path}\n")
112-
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
113-
stringWriter.toString()
114-
}
115-
}
116-
.joinToString("\n")
117-
}
118-
119-
private fun isPromptTooLarge(prompt: String): Boolean {
120-
val registry = Encodings.newDefaultEncodingRegistry()
121-
122-
/*
123-
* Try to find the model type based on the model id by finding the longest matching model type
124-
* If no model type matches, let the request go through and let the OpenAI API handle it
125-
*/
126-
val modelType = ModelType.values()
127-
.filter { AppSettings.instance.openAIModelId.contains(it.name) }
128-
.maxByOrNull { it.name.length }
129-
?: return false
130-
131-
val encoding = registry.getEncoding(modelType.encodingType)
132-
return encoding.countTokens(prompt) > modelType.maxContextLength
133-
}
134-
135-
private fun commonBranch(changes: List<Change>, project: Project): String? {
136-
val repositoryManager = GitRepositoryManager.getInstance(project)
137-
return changes.map {
138-
repositoryManager.getRepositoryForFileQuick(it.virtualFile)?.currentBranchName
139-
}.groupingBy { it }.eachCount().maxByOrNull { it.value }?.key
140-
}
141-
}
67+
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/AICommitsUtils.kt

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
package com.github.blarc.ai.commits.intellij.plugin
22

3+
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
4+
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
35
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
46
import com.github.blarc.ai.commits.intellij.plugin.settings.ProjectSettings
57
import com.intellij.openapi.components.service
8+
import com.intellij.openapi.diff.impl.patch.IdeaTextPatchBuilder
9+
import com.intellij.openapi.diff.impl.patch.UnifiedDiffWriter
610
import com.intellij.openapi.project.Project
11+
import com.intellij.openapi.vcs.changes.Change
12+
import com.knuddels.jtokkit.Encodings
13+
import com.knuddels.jtokkit.api.ModelType
14+
import git4idea.repo.GitRepositoryManager
15+
import java.io.StringWriter
716
import java.nio.file.FileSystems
817

918
object AICommitsUtils {
1019

1120
fun isPathExcluded(path: String, project: Project) : Boolean {
1221
return !AppSettings.instance.isPathExcluded(path) && !project.service<ProjectSettings>().isPathExcluded(path)
1322
}
23+
1424
fun matchesGlobs(text: String, globs: Set<String>): Boolean {
1525
val fileSystem = FileSystems.getDefault()
1626
for (globString in globs) {
@@ -21,4 +31,89 @@ object AICommitsUtils {
2131
}
2232
return false
2333
}
24-
}
34+
35+
fun constructPrompt(promptContent: String, diff: String, branch: String): String {
36+
var content = promptContent
37+
content = content.replace("{locale}", AppSettings.instance.locale.displayLanguage)
38+
content = content.replace("{branch}", branch)
39+
40+
return if (content.contains("{diff}")) {
41+
content.replace("{diff}", diff)
42+
} else {
43+
"$content\n$diff"
44+
}
45+
}
46+
47+
fun commonBranch(changes: List<Change>, project: Project): String {
48+
val repositoryManager = GitRepositoryManager.getInstance(project)
49+
var branch = changes.map {
50+
repositoryManager.getRepositoryForFileQuick(it.virtualFile)?.currentBranchName
51+
}.groupingBy { it }.eachCount().maxByOrNull { it.value }?.key
52+
53+
if (branch == null) {
54+
sendNotification(Notification.noCommonBranch())
55+
// hardcoded fallback branch
56+
branch = "main"
57+
}
58+
return branch
59+
}
60+
61+
fun computeDiff(
62+
includedChanges: List<Change>,
63+
project: Project
64+
): String {
65+
66+
val gitRepositoryManager = GitRepositoryManager.getInstance(project)
67+
68+
// go through included changes, create a map of repository to changes and discard nulls
69+
val changesByRepository = includedChanges
70+
.filter {
71+
it.virtualFile?.path?.let { path ->
72+
AICommitsUtils.isPathExcluded(path, project)
73+
} ?: false
74+
}
75+
.mapNotNull { change ->
76+
change.virtualFile?.let { file ->
77+
gitRepositoryManager.getRepositoryForFileQuick(
78+
file
79+
) to change
80+
}
81+
}
82+
.groupBy({ it.first }, { it.second })
83+
84+
85+
// compute diff for each repository
86+
return changesByRepository
87+
.map { (repository, changes) ->
88+
repository?.let {
89+
val filePatches = IdeaTextPatchBuilder.buildPatch(
90+
project,
91+
changes,
92+
repository.root.toNioPath(), false, true
93+
)
94+
95+
val stringWriter = StringWriter()
96+
stringWriter.write("Repository: ${repository.root.path}\n")
97+
UnifiedDiffWriter.write(project, filePatches, stringWriter, "\n", null)
98+
stringWriter.toString()
99+
}
100+
}
101+
.joinToString("\n")
102+
}
103+
104+
fun isPromptTooLarge(prompt: String): Boolean {
105+
val registry = Encodings.newDefaultEncodingRegistry()
106+
107+
/*
108+
* Try to find the model type based on the model id by finding the longest matching model type
109+
* If no model type matches, let the request go through and let the OpenAI API handle it
110+
*/
111+
val modelType = ModelType.entries
112+
.filter { AppSettings.instance.openAIModelId.contains(it.name) }
113+
.maxByOrNull { it.name.length }
114+
?: return false
115+
116+
val encoding = registry.getEncoding(modelType.encodingType)
117+
return encoding.countTokens(prompt) > modelType.maxContextLength
118+
}
119+
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AppSettings.kt

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,6 @@ class AppSettings : PersistentStateComponent<AppSettings> {
5151
get() = ApplicationManager.getApplication().getService(AppSettings::class.java)
5252
}
5353

54-
fun getPrompt(diff: String, branch: String): String {
55-
var content = currentPrompt.content
56-
content = content.replace("{locale}", locale.displayLanguage)
57-
content = content.replace("{branch}", branch)
58-
59-
return if (content.contains("{diff}")) {
60-
content.replace("{diff}", diff)
61-
} else {
62-
"$content\n$diff"
63-
}
64-
}
65-
6654
fun saveOpenAIToken(token: String) {
6755
try {
6856
PasswordSafe.instance.setPassword(getCredentialAttributes(openAITokenTitle), token)

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/prompt/PromptTable.kt

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,33 @@ package com.github.blarc.ai.commits.intellij.plugin.settings.prompt
22

33
import ai.grazie.utils.applyIf
44
import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
5+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils
6+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
7+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
8+
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.isPromptTooLarge
59
import com.github.blarc.ai.commits.intellij.plugin.createColumn
610
import com.github.blarc.ai.commits.intellij.plugin.notBlank
711
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
812
import com.github.blarc.ai.commits.intellij.plugin.unique
13+
import com.intellij.dvcs.repo.VcsRepositoryManager
14+
import com.intellij.ide.DataManager
15+
import com.intellij.openapi.actionSystem.CommonDataKeys
916
import com.intellij.openapi.ui.DialogWrapper
1017
import com.intellij.ui.components.JBTextArea
1118
import com.intellij.ui.components.JBTextField
1219
import com.intellij.ui.dsl.builder.Align
1320
import com.intellij.ui.dsl.builder.bindText
1421
import com.intellij.ui.dsl.builder.panel
1522
import com.intellij.ui.table.TableView
23+
import com.intellij.ui.util.minimumWidth
24+
import com.intellij.ui.util.preferredHeight
25+
import com.intellij.ui.util.preferredWidth
1626
import com.intellij.util.ui.ListTableModel
27+
import git4idea.branch.GitBranchWorker
1728
import java.awt.event.MouseAdapter
1829
import java.awt.event.MouseEvent
1930
import javax.swing.ListSelectionModel.SINGLE_SELECTION
31+
import kotlin.math.max
2032

2133
class PromptTable {
2234
private var prompts = AppSettings.instance.prompts
@@ -98,14 +110,18 @@ class PromptTable {
98110
val promptNameTextField = JBTextField()
99111
val promptDescriptionTextField = JBTextField()
100112
val promptContentTextArea = JBTextArea()
113+
val promptPreviewTextArea = JBTextArea()
114+
lateinit var branch: String
115+
lateinit var diff: String
101116

102117
init {
103118
title = newPrompt?.let { message("settings.prompt.edit.title") } ?: message("settings.prompt.add.title")
104119
setOKButtonText(newPrompt?.let { message("actions.update") } ?: message("actions.add"))
105-
setSize(700, 500)
106120

107121
promptContentTextArea.wrapStyleWord = true
108122
promptContentTextArea.lineWrap = true
123+
promptContentTextArea.rows = 15
124+
promptContentTextArea.autoscrolls = false
109125

110126
if (!prompt.canBeChanged) {
111127
isOKActionEnabled = false
@@ -114,6 +130,25 @@ class PromptTable {
114130
promptContentTextArea.isEditable = false
115131
}
116132

133+
promptPreviewTextArea.wrapStyleWord = true
134+
promptPreviewTextArea.lineWrap = true
135+
promptPreviewTextArea.isEditable = false
136+
promptPreviewTextArea.rows = 25
137+
promptPreviewTextArea.columns = 100
138+
promptPreviewTextArea.autoscrolls = false
139+
140+
DataManager.getInstance().dataContextFromFocusAsync.onSuccess {
141+
val project = it.getData(CommonDataKeys.PROJECT)
142+
val changes = VcsRepositoryManager.getInstance(project!!).repositories.stream()
143+
.map { r -> GitBranchWorker.loadTotalDiff(r, r.currentBranchName!!) }
144+
.flatMap { r -> r.stream() }
145+
.toList()
146+
147+
branch = commonBranch(changes, project)
148+
diff = computeDiff(changes, project)
149+
setPreview(prompt.content)
150+
}
151+
117152
init()
118153
}
119154

@@ -135,17 +170,30 @@ class PromptTable {
135170
row {
136171
label(message("settings.prompt.content"))
137172
}
138-
row() {
139-
cell(promptContentTextArea)
140-
.align(Align.FILL)
173+
row {
174+
scrollCell(promptContentTextArea)
141175
.bindText(prompt::content)
142176
.validationOnApply { notBlank(it.text) }
143-
.resizableColumn()
144-
}.resizableRow()
177+
.onChanged { setPreview(it.text)}
178+
.align(Align.FILL)
179+
}
180+
row {
181+
label("Preview")
182+
}
183+
row {
184+
scrollCell(promptPreviewTextArea)
185+
.align(Align.FILL)
186+
}
145187
row {
146188
comment(message("settings.prompt.comment"))
147189
}
148190
}
149191

192+
private fun setPreview(promptContent: String) {
193+
val constructPrompt = AICommitsUtils.constructPrompt(promptContent, diff, branch)
194+
promptPreviewTextArea.text = constructPrompt.substring(0, constructPrompt.length.coerceAtMost(10000))
195+
promptPreviewTextArea.caretPosition = max(0, promptContentTextArea.caretPosition - 10)
196+
}
197+
150198
}
151-
}
199+
}

src/main/resources/messages/MyBundle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ settings.prompt.content=Content
4646
validation.required=This value is required.
4747
validation.number=Value is not a number.
4848
validation.temperature=Temperature should be between 0 and 2.
49-
settings.prompt.comment=You can use variables {locale}, {diff} and {branch} to customise your prompt.
49+
settings.prompt.comment=You can use variables {locale}, {diff} and {branch} to customise your prompt. Prompt preview shows only the first 10000 characters.
5050
actions.update=Update
5151
actions.add=Add
5252
settings.prompt.edit.title=Edit Prompt

0 commit comments

Comments
 (0)