Commit 1077f09

feat(prompts): support amend commits (#230)
Move the prompt-building logic from AICommitAction to LLMClientService so that it runs inside coroutines. Closes #230
1 parent 1189812 commit 1077f09
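
In short: the action now only gathers UI context and delegates, while LLMClientService computes the diff (including the previous commit's changes when amending), builds the prompt, and runs the request in a coroutine. A condensed sketch of the action-side call after this change, summarizing the AICommitAction diff below; the helper name triggerGeneration is illustrative only and error notifications are omitted:

import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.vcs.VcsDataKeys
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler

// Hypothetical helper mirroring AICommitAction.actionPerformed after this commit.
fun triggerGeneration(e: AnActionEvent) {
    val llmClient = AppSettings2.instance.getActiveLLMClientConfiguration() ?: return
    val commitWorkflowHandler =
        e.getData(VcsDataKeys.COMMIT_WORKFLOW_HANDLER) as? AbstractCommitWorkflowHandler<*, *> ?: return
    val commitMessage =
        VcsDataKeys.COMMIT_MESSAGE_CONTROL.getData(e.dataContext) as? CommitMessage ?: return
    val project = e.project ?: return

    // All prompt building (diff, branch, hint, amend handling) now happens inside the service.
    llmClient.generateCommitMessage(commitWorkflowHandler, commitMessage, project)
}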

File tree

9 files changed: 73 additions, 45 deletions

CHANGELOG.md
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/AICommitAction.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientConfiguration.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/anthropic/AnthropicClientConfiguration.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/gemini/GeminiClientConfiguration.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/ollama/OllamaClientConfiguration.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/openAi/OpenAiClientConfiguration.kt
src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/qianfan/QianfanClientConfiguration.kt

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 ### Added

 - Option to choose prompt per project.
+- Amending commits now adds the changes from previous commit to the prompt.

 ### Fixed

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/AICommitAction.kt

Lines changed: 3 additions & 25 deletions

@@ -1,8 +1,5 @@
 package com.github.blarc.ai.commits.intellij.plugin

-import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
-import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
-import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
 import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
 import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
 import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
@@ -17,42 +14,23 @@ class AICommitAction : AnAction(), DumbAware {
     override fun actionPerformed(e: AnActionEvent) {
         val llmClient = AppSettings2.instance.getActiveLLMClientConfiguration()
         if (llmClient == null) {
-            Notification.clientNotSet()
+            sendNotification(Notification.clientNotSet())
             return
         }
-        val project = e.project ?: return

         val commitWorkflowHandler = e.getData(VcsDataKeys.COMMIT_WORKFLOW_HANDLER) as AbstractCommitWorkflowHandler<*, *>?
         if (commitWorkflowHandler == null) {
             sendNotification(Notification.noCommitMessage())
             return
         }

-        val includedChanges = commitWorkflowHandler.ui.getIncludedChanges()
         val commitMessage = VcsDataKeys.COMMIT_MESSAGE_CONTROL.getData(e.dataContext) as CommitMessage?
-
-        val diff = computeDiff(includedChanges, false, project)
-        if (diff.isBlank()) {
-            sendNotification(Notification.emptyDiff())
-            return
-        }
-
-        val branch = commonBranch(includedChanges, project)
-        val hint = commitMessage?.text
-
-        val prompt = constructPrompt(AppSettings2.instance.activePrompt.content, diff, branch, hint, project)
-
-        // TODO @Blarc: add support for different clients
-        // if (isPromptTooLarge(prompt)) {
-        //     sendNotification(Notification.promptTooLarge())
-        //     return@runBackgroundableTask
-        // }
-
         if (commitMessage == null) {
             sendNotification(Notification.noCommitMessage())
             return
         }

-        llmClient.generateCommitMessage(prompt, project, commitMessage)
+        val project = e.project ?: return
+        llmClient.generateCommitMessage(commitWorkflowHandler, commitMessage, project)
     }
 }

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientConfiguration.kt

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@ import com.intellij.openapi.project.Project
 import com.intellij.openapi.ui.ComboBox
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import java.util.*
 import javax.swing.Icon

@@ -38,7 +39,7 @@
         getSharedState().modelIds.add(modelId)
     }

-    abstract fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage)
+    abstract fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project)

     abstract fun getRefreshModelsFunction(): ((ComboBox<String>) -> Unit)?
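
Each concrete configuration now forwards the workflow handler to its service instead of receiving a pre-built prompt; the five client diffs below all make the same mechanical change. For a hypothetical additional client the override reduces to a single delegation (MyClientService and its getInstance() are illustrative names, not code from this commit):

override fun generateCommitMessage(
    commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>,
    commitMessage: CommitMessage,
    project: Project
) {
    // Delegate to the client's service; diff computation, amend handling and
    // prompt construction all happen in LLMClientService.generateCommitMessage.
    return MyClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
}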

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt

Lines changed: 51 additions & 8 deletions
@@ -1,31 +1,62 @@
 package com.github.blarc.ai.commits.intellij.plugin.settings.clients

 import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
+import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.commonBranch
+import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.computeDiff
+import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.constructPrompt
+import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
+import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
 import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
 import com.github.blarc.ai.commits.intellij.plugin.wrap
 import com.intellij.icons.AllIcons
 import com.intellij.openapi.application.EDT
 import com.intellij.openapi.application.ModalityState
 import com.intellij.openapi.application.asContextElement
 import com.intellij.openapi.project.Project
+import com.intellij.openapi.vcs.changes.Change
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.platform.ide.progress.withBackgroundProgress
 import com.intellij.ui.components.JBLabel
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
+import com.intellij.vcs.commit.isAmendCommitMode
 import dev.langchain4j.data.message.UserMessage
 import dev.langchain4j.model.chat.ChatLanguageModel
+import git4idea.GitCommit
+import git4idea.history.GitHistoryUtils
+import git4idea.repo.GitRepositoryManager
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext

-abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: CoroutineScope) {
+abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {

-    abstract suspend fun buildChatModel(client: T): ChatLanguageModel
+    abstract suspend fun buildChatModel(client: C): ChatLanguageModel

-    fun generateCommitMessage(client: T, prompt: String, project: Project, commitMessage: CommitMessage) {
-        cs.launch(Dispatchers.IO + ModalityState.current().asContextElement()) {
+    fun generateCommitMessage(clientConfiguration: C, commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+
+        val commitContext = commitWorkflowHandler.workflow.commitContext
+        val includedChanges = commitWorkflowHandler.ui.getIncludedChanges().toMutableList()
+
+        cs.launch(ModalityState.current().asContextElement()) {
             withBackgroundProgress(project, message("action.background")) {
-                sendRequest(client, prompt, onSuccess = {
+
+                if (commitContext.isAmendCommitMode) {
+                    includedChanges += getLastCommitChanges(project)
+                }
+
+                val diff = computeDiff(includedChanges, false, project)
+                if (diff.isBlank()) {
+                    withContext(Dispatchers.EDT) {
+                        sendNotification(Notification.emptyDiff())
+                    }
+                    return@withBackgroundProgress
+                }
+
+                val branch = commonBranch(includedChanges, project)
+                val prompt = constructPrompt(AppSettings2.instance.activePrompt.content, diff, branch, commitMessage.text, project)
+
+                sendRequest(clientConfiguration, prompt, onSuccess = {
                     withContext(Dispatchers.EDT) {
                         commitMessage.setCommitMessage(it)
                     }
@@ -39,9 +70,9 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
         }
     }

-    fun verifyConfiguration(client: T, label: JBLabel) {
+    fun verifyConfiguration(client: C, label: JBLabel) {
         label.text = message("settings.verify.running")
-        cs.launch(Dispatchers.IO + ModalityState.current().asContextElement()) {
+        cs.launch(ModalityState.current().asContextElement()) {
             sendRequest(client, "test", onSuccess = {
                 withContext(Dispatchers.EDT) {
                     label.text = message("settings.verify.valid")
@@ -56,7 +87,7 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
         }
     }

-    private suspend fun sendRequest(client: T, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
+    private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
         try {
             val model = buildChatModel(client)
             val response = withContext(Dispatchers.IO) {
@@ -78,4 +109,16 @@ abstract class LLMClientService<T : LLMClientConfiguration>(private val cs: Coro
             throw e
         }
     }
+
+    private suspend fun getLastCommitChanges(project: Project): List<Change> {
+        return withContext(Dispatchers.IO) {
+            GitRepositoryManager.getInstance(project).repositories.map { repo ->
+                GitHistoryUtils.history(project, repo.root, "--max-count=1")
+            }.filter { commits ->
+                commits.isNotEmpty()
+            }.map { commits ->
+                (commits.first() as GitCommit).changes
+            }.flatten()
+        }
+    }
 }
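
The concrete client services are untouched by this commit apart from the renamed type parameter; their only obligation is still buildChatModel. A minimal sketch of such a subclass, assuming langchain4j's OpenAiChatModel and an apiKey property on the configuration (the class name MyOpenAiLikeClientService and the apiKey field are assumptions for illustration, not code from this repository):

import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientService
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.openAi.OpenAiClientConfiguration
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.openai.OpenAiChatModel
import kotlinx.coroutines.CoroutineScope

// Hypothetical subclass for illustration; the real OpenAiClientService is not part of this diff.
class MyOpenAiLikeClientService(cs: CoroutineScope) : LLMClientService<OpenAiClientConfiguration>(cs) {

    // The only abstract member a subclass supplies: turn its configuration into a chat model.
    // Diff computation, amend handling, prompt construction and threading live in the base class.
    override suspend fun buildChatModel(client: OpenAiClientConfiguration): ChatLanguageModel =
        OpenAiChatModel.builder()
            .apiKey(client.apiKey)     // assumed property name on the configuration
            .modelName(client.modelId) // modelId is declared on LLMClientConfiguration
            .build()
}

Keeping the request pipeline in the shared base class is what lets the amend-mode handling added in this commit apply to every client at once.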

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/anthropic/AnthropicClientConfiguration.kt

Lines changed: 4 additions & 3 deletions
@@ -1,4 +1,4 @@
-package com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic;
+package com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic

 import com.github.blarc.ai.commits.intellij.plugin.Icons
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientConfiguration
@@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
 import com.intellij.util.xmlb.annotations.Transient
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import dev.langchain4j.model.anthropic.AnthropicChatModelName
 import javax.swing.Icon

@@ -44,8 +45,8 @@ class AnthropicClientConfiguration : LLMClientConfiguration(
         return AnthropicClientSharedState.getInstance()
     }

-    override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
-        return AnthropicClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
+    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+        return AnthropicClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
     }

     override fun getRefreshModelsFunction() = null

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/gemini/GeminiClientConfiguration.kt

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@ import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientSha
 import com.intellij.openapi.project.Project
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import javax.swing.Icon

 class GeminiClientConfiguration : LLMClientConfiguration(
@@ -34,8 +35,8 @@ class GeminiClientConfiguration : LLMClientConfiguration(
         return GeminiClientSharedState.getInstance()
     }

-    override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
-        return GeminiClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
+    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+        return GeminiClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
     }

     // Model names are hard-coded and do not need to be refreshed.

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/ollama/OllamaClientConfiguration.kt

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@ import com.intellij.openapi.project.Project
 import com.intellij.openapi.ui.ComboBox
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import javax.swing.Icon

 class OllamaClientConfiguration : LLMClientConfiguration(
@@ -36,8 +37,8 @@ class OllamaClientConfiguration : LLMClientConfiguration(
         return OllamaClientSharedState.getInstance()
     }

-    override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
-        return OllamaClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
+    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+        return OllamaClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
     }

     override fun getRefreshModelsFunction() = fun (cb: ComboBox<String>) {

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/openAi/OpenAiClientConfiguration.kt

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@ import com.intellij.openapi.project.Project
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
 import com.intellij.util.xmlb.annotations.Transient
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import javax.swing.Icon

 class OpenAiClientConfiguration : LLMClientConfiguration(
@@ -43,8 +44,8 @@ class OpenAiClientConfiguration : LLMClientConfiguration(
         return OpenAiClientSharedState.getInstance()
     }

-    override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
-        return OpenAiClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
+    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+        return OpenAiClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
     }

     // Model names are retrieved from Enum and do not need to be refreshed.

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/qianfan/QianfanClientConfiguration.kt

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@ import com.intellij.openapi.project.Project
 import com.intellij.openapi.vcs.ui.CommitMessage
 import com.intellij.util.xmlb.annotations.Attribute
 import com.intellij.util.xmlb.annotations.Transient
+import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
 import dev.langchain4j.model.qianfan.QianfanChatModelNameEnum
 import javax.swing.Icon

@@ -41,8 +42,8 @@ class QianfanClientConfiguration : LLMClientConfiguration(
         return QianfanClientSharedState.getInstance()
     }

-    override fun generateCommitMessage(prompt: String, project: Project, commitMessage: CommitMessage) {
-        return QianfanClientService.getInstance().generateCommitMessage(this, prompt, project, commitMessage)
+    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
+        return QianfanClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
     }

     // Model names are retrieved from Enum and do not need to be refreshed.
