Skip to content

Commit 70a43e7

Browse files
committed
feat(clients): replace openai-client with langchain4j
1 parent e7c11dc commit 70a43e7

File tree

10 files changed

+118
-126
lines changed

10 files changed

+118
-126
lines changed

build.gradle.kts

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ fun properties(key: String) = project.findProperty(key).toString()
66
plugins {
77
id("org.jetbrains.kotlin.jvm") version "2.0.0"
88
id("org.jetbrains.intellij") version "1.17.3"
9+
kotlin("plugin.serialization") version "1.9.23"
910

1011
// Gradle Changelog Plugin
1112
id("org.jetbrains.changelog") version "2.2.0"
@@ -102,18 +103,27 @@ tasks.test {
102103
}
103104

104105
dependencies {
105-
implementation("com.aallam.openai:openai-client:3.7.2") {
106-
exclude(group = "org.slf4j", module = "slf4j-api")
107-
// Prevents java.lang.LinkageError: java.lang.LinkageError: loader constraint violation:when resolving method 'long kotlin.time.Duration.toLong-impl(long, kotlin.time.DurationUnit)'
108-
exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib")
109-
}
110-
implementation("io.ktor:ktor-client-cio:2.3.11") {
111-
exclude(group = "org.slf4j", module = "slf4j-api")
112-
// Prevents java.lang.LinkageError: java.lang.LinkageError: loader constraint violation: when resolving method 'long kotlin.time.Duration.toLong-impl(long, kotlin.time.DurationUnit)'
113-
exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib")
114-
}
115-
116-
implementation("com.knuddels:jtokkit:1.0.0")
106+
// implementation("com.aallam.openai:openai-client:3.7.2") {
107+
// exclude(group = "org.slf4j", module = "slf4j-api")
108+
// // Prevents java.lang.LinkageError: java.lang.LinkageError: loader constraint violation:when resolving method 'long kotlin.time.Duration.toLong-impl(long, kotlin.time.DurationUnit)'
109+
// exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib")
110+
// }
111+
// implementation("io.ktor:ktor-client-cio:2.3.11") {
112+
// exclude(group = "org.slf4j", module = "slf4j-api")
113+
// // Prevents java.lang.LinkageError: java.lang.LinkageError: loader constraint violation: when resolving method 'long kotlin.time.Duration.toLong-impl(long, kotlin.time.DurationUnit)'
114+
// exclude(group = "org.jetbrains.kotlin", module = "kotlin-stdlib")
115+
// }
116+
//
117+
// implementation("com.knuddels:jtokkit:1.0.0")
118+
119+
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.0")
120+
121+
// langchain4j integrations
122+
implementation("dev.langchain4j:langchain4j-open-ai:0.29.1")
123+
implementation("dev.langchain4j:langchain4j-ollama:0.29.1")
124+
// implementation("dev.langchain4j:langchain4j-hugging-face:0.28.0")
125+
// implementation("dev.langchain4j:langchain4j-milvus:0.28.0")
126+
// implementation("dev.langchain4j:langchain4j-local-ai:0.28.0")
117127

118128
// tests
119129
testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.2")

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/AICommitAction.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ class AICommitAction : AnAction(), DumbAware {
5151
return@runBackgroundableTask
5252
}
5353

54-
val openAIService = OpenAIService.instance
54+
val llmClient = AppSettings2.instance.getActiveLLMClient()
5555
runBlocking(Dispatchers.Main) {
5656
try {
57-
val generatedCommitMessage = openAIService.generateCommitMessage(prompt)
57+
val generatedCommitMessage = llmClient.generateCommitMessage(prompt)
5858
commitMessage.setCommitMessage(generatedCommitMessage)
5959
AppSettings2.instance.recordHit()
6060
} catch (e: Exception) {

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/OpenAIService.kt

Lines changed: 0 additions & 29 deletions
This file was deleted.

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AICommitsListCellRenderer.kt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.github.blarc.ai.commits.intellij.plugin.settings
22

3-
import com.aallam.openai.api.model.ModelId
43
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClient
54
import com.github.blarc.ai.commits.intellij.plugin.settings.prompts.Prompt
65
import java.awt.Component
@@ -26,10 +25,6 @@ class AICommitsListCellRenderer : DefaultListCellRenderer() {
2625
text = value.name
2726
}
2827

29-
is ModelId -> {
30-
text = value.id
31-
}
32-
3328
is LLMClient -> {
3429
text = value.displayName
3530
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AppSettings.kt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.github.blarc.ai.commits.intellij.plugin.settings
22

3-
import com.aallam.openai.client.OpenAIHost
43
import com.github.blarc.ai.commits.intellij.plugin.settings.prompts.DefaultPrompts
54
import com.intellij.openapi.application.ApplicationManager
65
import com.intellij.openapi.components.PersistentStateComponent
@@ -29,8 +28,8 @@ class AppSettings : PersistentStateComponent<AppSettings> {
2928

3029
var requestSupport = true
3130
var lastVersion: String? = null
32-
var openAIHost = OpenAIHost.OpenAI.baseUrl
33-
var openAIHosts = mutableSetOf(OpenAIHost.OpenAI.baseUrl)
31+
var openAIHost = "https://api.openai.com/v1"
32+
var openAIHosts = mutableSetOf("https://api.openai.com/v1")
3433
var openAISocketTimeout = "30"
3534
var proxyUrl: String? = null
3635

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClient.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ abstract class LLMClient(
3030

3131
abstract suspend fun generateCommitMessage(prompt: String): String
3232

33-
abstract suspend fun refreshModels()
33+
abstract fun getRefreshModelFunction(): (suspend () -> Unit)?
3434

3535
public abstract override fun clone(): LLMClient
3636

@@ -39,6 +39,7 @@ abstract class LLMClient(
3939
newHost: String,
4040
newProxy: String?,
4141
newTimeout: String,
42+
newModelId: String,
4243
newToken: String
4344
)
4445

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientTable.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,23 @@ class LLMClientTable {
112112
}
113113

114114
override fun doOKAction() {
115-
(cardLayout.findComponentById(llmClient.displayName) as DialogPanel).apply()
115+
if (newLlmClient == null) {
116+
(cardLayout.findComponentById(llmClient.displayName) as DialogPanel).apply()
117+
}
116118
super.doOKAction()
117119
}
118120

119121
override fun createCenterPanel() = if (newLlmClient == null) {
120122
createCardSplitter()
121123
} else {
122124
llmClient.panel().create()
125+
}.apply {
126+
isResizable = false
123127
}
124128

125129
private fun getLlmClients(newLLMClient: LLMClient?): List<LLMClient> {
126130
return if (newLLMClient == null) {
127-
// TODO: Find a better way to create the list of all possible LLM Clients that implement LLMClient abstract class
131+
// TODO(@Blarc): Is there a better way to create the list of all possible LLM Clients that implement LLMClient abstract class
128132
listOf(
129133
OpenAIClient(),
130134
TestAIClient()
Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
package com.github.blarc.ai.commits.intellij.plugin.settings.clients
22

3-
import com.aallam.openai.api.chat.ChatCompletion
4-
import com.aallam.openai.api.chat.ChatCompletionRequest
5-
import com.aallam.openai.api.chat.ChatMessage
6-
import com.aallam.openai.api.chat.ChatRole
7-
import com.aallam.openai.api.http.Timeout
8-
import com.aallam.openai.api.model.ModelId
9-
import com.aallam.openai.client.OpenAI
10-
import com.aallam.openai.client.OpenAIConfig
11-
import com.aallam.openai.client.OpenAIHost
12-
import com.aallam.openai.client.ProxyConfig
133
import com.github.blarc.ai.commits.intellij.plugin.Icons
4+
import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings
145
import com.intellij.util.xmlb.annotations.Attribute
6+
import dev.langchain4j.data.message.UserMessage
7+
import dev.langchain4j.model.openai.OpenAiChatModel
8+
import dev.langchain4j.model.openai.OpenAiChatModelName
9+
import java.net.InetSocketAddress
10+
import java.net.Proxy
11+
import java.net.URI
12+
import java.time.Duration
1513
import javax.swing.Icon
16-
import kotlin.time.Duration.Companion.seconds
1714

1815
class OpenAIClient(displayName: String = "OpenAI") : LLMClient(
1916
displayName,
20-
OpenAIHost.OpenAI.baseUrl,
17+
"https://api.openai.com/v1",
2118
null,
2219
30,
2320
"gpt-3.5-turbo",
@@ -27,9 +24,12 @@ class OpenAIClient(displayName: String = "OpenAI") : LLMClient(
2724
companion object {
2825
// TODO @Blarc: Static fields probably can't be attributes...
2926
@Attribute
30-
val hosts = mutableSetOf(OpenAIHost.OpenAI.baseUrl)
27+
val hosts = mutableSetOf("https://api.openai.com/v1")
3128
@Attribute
32-
val modelIds = mutableSetOf("gpt-3.5-turbo", "gpt-4")
29+
val modelIds = OpenAiChatModelName.entries.stream()
30+
.map { it.toString() }
31+
.toList()
32+
.toMutableSet()
3333
}
3434

3535
override fun getIcon(): Icon {
@@ -47,33 +47,28 @@ class OpenAIClient(displayName: String = "OpenAI") : LLMClient(
4747
override suspend fun generateCommitMessage(
4848
prompt: String
4949
): String {
50+
val openAI = OpenAiChatModel.builder()
51+
.apiKey(token)
52+
.modelName(modelId)
53+
.temperature(temperature.toDouble())
54+
.timeout(Duration.ofSeconds(timeout.toLong()))
55+
.baseUrl(AppSettings.instance.openAIHost)
56+
.build()
5057

51-
val openAI = OpenAI(openAIConfig())
52-
val chatCompletionRequest = ChatCompletionRequest(
53-
ModelId(modelId),
58+
val response = openAI.generate(
5459
listOf(
55-
ChatMessage(
56-
role = ChatRole.User,
57-
content = prompt
60+
UserMessage.from(
61+
"user",
62+
prompt
5863
)
59-
),
60-
temperature = temperature.toDouble(),
61-
topP = 1.0,
62-
frequencyPenalty = 0.0,
63-
presencePenalty = 0.0,
64-
maxTokens = 200,
65-
n = 1
64+
)
6665
)
67-
68-
val completion: ChatCompletion = openAI.chatCompletion(chatCompletionRequest)
69-
return completion.choices[0].message.content ?: "API returned an empty response."
66+
return response.content().text()
7067
}
7168

72-
override suspend fun refreshModels() {
73-
val openAI = OpenAI(openAIConfig())
74-
openAI.models()
75-
.map { it.id.id }
76-
.forEach { modelIds.add(it) }
69+
override fun getRefreshModelFunction(): (suspend () -> Unit)? {
70+
// Model names are retrieved from Enum and do not need to be refreshed.
71+
return null
7772
}
7873

7974
override fun clone(): LLMClient {
@@ -86,32 +81,38 @@ class OpenAIClient(displayName: String = "OpenAI") : LLMClient(
8681
return copy
8782
}
8883

89-
@Throws(Exception::class)
9084
override suspend fun verifyConfiguration(
9185
newHost: String,
9286
newProxy: String?,
9387
newTimeout: String,
88+
newModelId: String,
9489
newToken: String
9590
) {
9691

97-
val newConfig = OpenAIConfig(
98-
newToken,
99-
host = newHost.takeIf { it.isNotBlank() }?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
100-
proxy = newProxy?.takeIf { it.isNotBlank() }?.let { ProxyConfig.Http(it) },
101-
timeout = Timeout(socket = newTimeout.toInt().seconds)
92+
val openAiBuilder = OpenAiChatModel.builder()
93+
.apiKey(newToken)
94+
.modelName(newModelId)
95+
.temperature(temperature.toDouble())
96+
.timeout(Duration.ofSeconds(newTimeout.toLong()))
97+
98+
newHost.takeIf { it.isNotBlank() }?.let { openAiBuilder.baseUrl(it) }
99+
newProxy?.takeIf { it.isNotBlank() }?.let {
100+
val uri = URI(it)
101+
openAiBuilder.proxy(Proxy(Proxy.Type.HTTP, InetSocketAddress(uri.host, uri.port)))
102+
}
103+
104+
val openAi = openAiBuilder.build()
105+
openAi.generate(
106+
listOf(
107+
UserMessage.from(
108+
"user",
109+
"t"
110+
)
111+
)
102112
)
103-
val openAI = OpenAI(newConfig)
104-
openAI.models()
105113
}
106114

107115
override fun panel(): LLMClientPanel {
108116
return OpenAIClientPanel(this)
109117
}
110-
111-
private fun openAIConfig() = OpenAIConfig(
112-
token,
113-
host = host.takeIf { it.isNotBlank() }?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
114-
proxy = proxyUrl?.takeIf { it.isNotBlank() }?.let { ProxyConfig.Http(it) },
115-
timeout = Timeout(socket = timeout.seconds)
116-
)
117118
}

0 commit comments

Comments
 (0)