Skip to content

Commit a0db29c

Browse files
committed
feat(clients): show progress for streaming response
1 parent 0772d43 commit a0db29c

File tree

1 file changed

+19
-14
lines changed
  • src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients

1 file changed

+19
-14
lines changed

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientService.kt

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@ import dev.langchain4j.data.message.UserMessage
2626
import dev.langchain4j.model.StreamingResponseHandler
2727
import dev.langchain4j.model.chat.ChatLanguageModel
2828
import dev.langchain4j.model.chat.StreamingChatLanguageModel
29+
import dev.langchain4j.model.output.Response
2930
import git4idea.GitCommit
3031
import git4idea.history.GitHistoryUtils
3132
import git4idea.repo.GitRepositoryManager
32-
import kotlinx.coroutines.CoroutineScope
33-
import kotlinx.coroutines.Dispatchers
34-
import kotlinx.coroutines.launch
35-
import kotlinx.coroutines.withContext
33+
import kotlinx.coroutines.*
3634

3735
abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: CoroutineScope) {
3836

@@ -81,7 +79,7 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
8179
fun verifyConfiguration(client: C, label: JBLabel) {
8280
label.text = message("settings.verify.running")
8381
cs.launch(ModalityState.current().asContextElement()) {
84-
sendRequest(client, "test", onSuccess = {
82+
makeRequest(client, "test", onSuccess = {
8583
withContext(Dispatchers.EDT) {
8684
label.text = message("settings.verify.valid")
8785
label.icon = AllIcons.General.InspectionsOK
@@ -99,11 +97,11 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
9997
try {
10098
if (AppSettings2.instance.useStreamingResponse) {
10199
buildStreamingChatModel(client)?.let { streamingChatModel ->
102-
sendStreamingRequest(streamingChatModel, text, onSuccess, onError)
100+
sendStreamingRequest(streamingChatModel, text, onSuccess)
103101
return
104102
}
105103
}
106-
sendRequest(client, text, onSuccess, onError)
104+
sendRequest(client, text, onSuccess)
107105
} catch (e: IllegalArgumentException) {
108106
onError(message("settings.verify.invalid", e.message ?: message("unknown-error")))
109107
} catch (e: Exception) {
@@ -113,8 +111,10 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
113111
}
114112
}
115113

116-
private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
114+
private suspend fun sendStreamingRequest(streamingModel: StreamingChatLanguageModel, text: String, onSuccess: suspend (r: String) -> Unit) {
117115
var response = ""
116+
val completionDeferred = CompletableDeferred<String>()
117+
118118
withContext(Dispatchers.IO) {
119119
streamingModel.generate(
120120
listOf(
@@ -131,18 +131,23 @@ abstract class LLMClientService<C : LLMClientConfiguration>(private val cs: Coro
131131
}
132132
}
133133

134-
override fun onError(error: Throwable?) {
135-
response = error?.message.toString()
136-
cs.launch {
137-
onError(response)
138-
}
134+
override fun onError(error: Throwable) {
135+
completionDeferred.completeExceptionally(error)
136+
}
137+
138+
override fun onComplete(response: Response<AiMessage>) {
139+
super.onComplete(response)
140+
completionDeferred.complete(response.content().text())
139141
}
140142
}
141143
)
144+
// completionDeferred.await() rethrows any exception passed to completionDeferred.completeExceptionally(error),
145+
// which is then handled by the caller of this function
146+
onSuccess(completionDeferred.await())
142147
}
143148
}
144149

145-
private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit, onError: suspend (r: String) -> Unit) {
150+
private suspend fun sendRequest(client: C, text: String, onSuccess: suspend (r: String) -> Unit) {
146151
val model = buildChatModel(client)
147152
val response = withContext(Dispatchers.IO) {
148153
model.generate(

0 commit comments

Comments
 (0)