@@ -14,6 +14,8 @@ import com.intellij.openapi.project.Project
1414import com.intellij.openapi.vcs.VcsDataKeys
1515import com.intellij.openapi.vcs.changes.Change
1616import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
17+ import com.knuddels.jtokkit.Encodings
18+ import com.knuddels.jtokkit.api.ModelType
1719import git4idea.repo.GitRepositoryManager
1820import kotlinx.coroutines.Dispatchers
1921import kotlinx.coroutines.runBlocking
@@ -40,20 +42,24 @@ class AICommitAction : AnAction(), DumbAware {
4042 return @runBackgroundableTask
4143 }
4244
45+ val prompt = AppSettings .instance.getPrompt(diff)
46+ if (isPromptTooLarge(prompt)) {
47+ sendNotification(Notification .promptTooLarge())
48+ return @runBackgroundableTask
49+ }
50+
4351 if (commitMessage == null ) {
4452 sendNotification(Notification .noCommitMessage())
4553 return @runBackgroundableTask
4654 }
4755
4856 val openAIService = OpenAIService .instance
49- val prompt = AppSettings .instance.getPrompt(diff)
5057 runBlocking(Dispatchers .Main ) {
5158 try {
5259 val generatedCommitMessage = openAIService.generateCommitMessage(prompt, 1 )
5360 commitMessage.setCommitMessage(generatedCommitMessage)
5461 AppSettings .instance.recordHit()
55- }
56- catch (e: Exception ) {
62+ } catch (e: Exception ) {
5763 commitMessage.setCommitMessage(message(" action.error" ))
5864 sendNotification(Notification .unsuccessfulRequest(e.message ? : message(" action.unknown-error" )))
5965 }
@@ -62,40 +68,56 @@ class AICommitAction : AnAction(), DumbAware {
6268 }
6369
6470 private fun computeDiff (
65- includedChanges : List <Change >,
66- project : Project
71+ includedChanges : List <Change >,
72+ project : Project
6773 ): String {
6874
6975 val gitRepositoryManager = GitRepositoryManager .getInstance(project)
7076
7177 // go through included changes, create a map of repository to changes and discard nulls
7278 val changesByRepository = includedChanges
73- .mapNotNull { change ->
74- change.virtualFile?.let { file ->
75- gitRepositoryManager.getRepositoryForFileQuick(
76- file
77- ) to change
79+ .mapNotNull { change ->
80+ change.virtualFile?.let { file ->
81+ gitRepositoryManager.getRepositoryForFileQuick(
82+ file
83+ ) to change
84+ }
7885 }
79- }
80- .groupBy({ it.first }, { it.second })
86+ .groupBy({ it.first }, { it.second })
8187
8288
8389 // compute diff for each repository
8490 return changesByRepository
85- .map { (repository, changes) ->
86- repository?.let {
87- val filePatches = IdeaTextPatchBuilder .buildPatch(
88- project,
89- changes,
90- repository.root.toNioPath(), false , true
91- )
92-
93- val stringWriter = StringWriter ()
94- stringWriter.write(" Repository: ${repository.root.path} \n " )
95- UnifiedDiffWriter .write(project, filePatches, stringWriter, " \n " , null )
96- stringWriter.toString()
91+ .map { (repository, changes) ->
92+ repository?.let {
93+ val filePatches = IdeaTextPatchBuilder .buildPatch(
94+ project,
95+ changes,
96+ repository.root.toNioPath(), false , true
97+ )
98+
99+ val stringWriter = StringWriter ()
100+ stringWriter.write(" Repository: ${repository.root.path} \n " )
101+ UnifiedDiffWriter .write(project, filePatches, stringWriter, " \n " , null )
102+ stringWriter.toString()
103+ }
97104 }
98- }
99- .joinToString(" \n " )
105+ .joinToString(" \n " )
106+ }
107+
108+ private fun isPromptTooLarge (prompt : String ): Boolean {
109+ val registry = Encodings .newDefaultEncodingRegistry()
110+
111+ /*
112+ * Try to find the model type based on the model id by finding the longest matching model type
113+ * If no model type matches, let the request go through and let the OpenAI API handle it
114+ */
115+ val modelType = ModelType .values()
116+ .filter { AppSettings .instance.openAIModelId.contains(it.name) }
117+ .maxByOrNull { it.name.length }
118+ ? : return false
119+
120+ val encoding = registry.getEncoding(modelType.encodingType)
121+ return encoding.countTokens(prompt) > modelType.maxContextLength
100122 }
101123}
0 commit comments