
Commit fa2ad4a

feat(huggingface): add support for HuggingFace (#256)
Closes #256
1 parent 06cde93 commit fa2ad4a

File tree

12 files changed: +266, -2 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Added
+
+- Support for Hugging Face.
+
 ## [2.5.0] - 2024-09-22
 
 ### Added

README.md

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ plugin and configure a LLM API client in plugin's settings: <kbd>Settings</kbd>
 - Anthropic
 - Azure Open AI
 - Gemini
+- Hugging Face
 - Open AI
 - Ollama
 - Qianfan (Ernie)

build.gradle.kts

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ dependencies {
     implementation("dev.langchain4j:langchain4j-vertex-ai-gemini")
     implementation("dev.langchain4j:langchain4j-anthropic")
     implementation("dev.langchain4j:langchain4j-azure-open-ai")
+    implementation("dev.langchain4j:langchain4j-hugging-face")
 
     // tests
     testImplementation("org.junit.jupiter:junit-jupiter-params:5.11.2")

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/Icons.kt

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@ object Icons {
     val GEMINI = IconLoader.getIcon("/icons/gemini.png", javaClass)
     val ANTHROPIC = IconLoader.getIcon("/icons/anthropic.svg", javaClass)
     val AZURE_OPEN_AI = IconLoader.getIcon("/icons/azureOpenAi.svg", javaClass)
+    val HUGGING_FACE = IconLoader.getIcon("/icons/huggingface.svg", javaClass)
 }

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/AppSettings2.kt

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,7 @@ import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientCon
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic.AnthropicClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.azureOpenAi.AzureOpenAiClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.gemini.GeminiClientConfiguration
+import com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface.HuggingFaceClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.ollama.OllamaClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.openAi.OpenAiClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.openAi.OpenAiClientSharedState
@@ -57,7 +58,8 @@ class AppSettings2 : PersistentStateComponent<AppSettings2> {
             QianfanClientConfiguration::class,
             GeminiClientConfiguration::class,
             AnthropicClientConfiguration::class,
-            AzureOpenAiClientConfiguration::class
+            AzureOpenAiClientConfiguration::class,
+            HuggingFaceClientConfiguration::class
         ],
         style = XCollection.Style.v2
     )
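
The second hunk above extends the list of concrete configuration classes registered for IntelliJ's XML serialization. A sketch of the pattern for context: the containing class in this plugin is AppSettings2, but the property name and default used here are illustrative assumptions; the commit itself only appends HuggingFaceClientConfiguration::class to the existing elementTypes list so that saved Hugging Face clients can be serialized and deserialized alongside the other client types.

import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientConfiguration
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.azureOpenAi.AzureOpenAiClientConfiguration
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface.HuggingFaceClientConfiguration
import com.intellij.util.xmlb.annotations.XCollection

// Illustrative fragment only: the real annotation lives in AppSettings2 on the
// persisted collection of LLM client configurations (property name assumed here).
class SettingsSketch {
    @XCollection(
        elementTypes = [
            // ...the other registered client configuration classes shown in the hunk above...
            AzureOpenAiClientConfiguration::class,
            HuggingFaceClientConfiguration::class // newly registered by this commit
        ],
        style = XCollection.Style.v2
    )
    var llmClientConfigurations: MutableList<LLMClientConfiguration> = mutableListOf()
}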

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/LLMClientTable.kt

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@ import com.github.blarc.ai.commits.intellij.plugin.settings.AppSettings2
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.anthropic.AnthropicClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.azureOpenAi.AzureOpenAiClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.gemini.GeminiClientConfiguration
+import com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface.HuggingFaceClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.ollama.OllamaClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.openAi.OpenAiClientConfiguration
 import com.github.blarc.ai.commits.intellij.plugin.settings.clients.qianfan.QianfanClientConfiguration
@@ -149,7 +150,8 @@ class LLMClientTable {
                 QianfanClientConfiguration(),
                 GeminiClientConfiguration(),
                 AnthropicClientConfiguration(),
-                AzureOpenAiClientConfiguration()
+                AzureOpenAiClientConfiguration(),
+                HuggingFaceClientConfiguration()
             )
         } else {
             listOf(newLLMClientConfiguration)

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/huggingface/HuggingFaceClientConfiguration.kt

Lines changed: 68 additions & 0 deletions (new file)
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface

import com.github.blarc.ai.commits.intellij.plugin.Icons
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientConfiguration
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientSharedState
import com.intellij.openapi.project.Project
import com.intellij.openapi.vcs.ui.CommitMessage
import com.intellij.util.xmlb.annotations.Attribute
import com.intellij.util.xmlb.annotations.Transient
import com.intellij.vcs.commit.AbstractCommitWorkflowHandler
import dev.langchain4j.model.huggingface.HuggingFaceModelName
import javax.swing.Icon

class HuggingFaceClientConfiguration : LLMClientConfiguration(
    "HuggingFace",
    HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT,
    "0.7"
) {

    @Attribute
    var timeout: Int = 30
    @Attribute
    var maxNewTokens: Int = 30
    @Attribute
    var waitForModel: Boolean = true
    @Attribute
    var tokenIsStored: Boolean = false
    @Transient
    var token: String? = null

    companion object {
        const val CLIENT_NAME = "HuggingFace"
    }

    override fun getClientName(): String {
        return CLIENT_NAME
    }

    override fun getClientIcon(): Icon {
        return Icons.HUGGING_FACE
    }

    override fun getSharedState(): LLMClientSharedState {
        return HuggingFaceClientSharedState.getInstance()
    }

    override fun generateCommitMessage(commitWorkflowHandler: AbstractCommitWorkflowHandler<*, *>, commitMessage: CommitMessage, project: Project) {
        return HuggingFaceClientService.getInstance().generateCommitMessage(this, commitWorkflowHandler, commitMessage, project)
    }

    override fun getRefreshModelsFunction() = null

    override fun clone(): LLMClientConfiguration {
        val copy = HuggingFaceClientConfiguration()
        copy.id = id
        copy.name = name
        copy.modelId = modelId
        copy.temperature = temperature
        copy.tokenIsStored = tokenIsStored
        copy.timeout = timeout
        copy.waitForModel = waitForModel
        copy.maxNewTokens = maxNewTokens
        return copy
    }

    override fun panel() = HuggingFaceClientPanel(this)
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/huggingface/HuggingFaceClientPanel.kt

Lines changed: 86 additions & 0 deletions (new file)
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface

import com.github.blarc.ai.commits.intellij.plugin.AICommitsBundle.message
import com.github.blarc.ai.commits.intellij.plugin.emptyText
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientPanel
import com.intellij.ui.components.JBCheckBox
import com.intellij.ui.components.JBPasswordField
import com.intellij.ui.components.JBTextField
import com.intellij.ui.dsl.builder.*

class HuggingFaceClientPanel private constructor(
    private val clientConfiguration: HuggingFaceClientConfiguration,
    val service: HuggingFaceClientService
) : LLMClientPanel(clientConfiguration) {

    private val tokenPasswordField = JBPasswordField()
    private val maxNewTokensTextField = JBTextField()
    private val waitForModelCheckBox = JBCheckBox()

    constructor(configuration: HuggingFaceClientConfiguration) : this(configuration, HuggingFaceClientService.getInstance())

    override fun create() = panel {
        nameRow()
        timeoutRow(clientConfiguration::timeout)
        tokenRow()
        modelIdRow()
        temperatureRow()
        maxNewTokens()
        waitForModel()
        verifyRow()
    }

    override fun verifyConfiguration() {
        // Configuration passed to the panel is already a copy of the original or a new configuration
        clientConfiguration.modelId = modelComboBox.item
        clientConfiguration.temperature = temperatureTextField.text
        clientConfiguration.timeout = socketTimeoutTextField.text.toInt()
        clientConfiguration.token = String(tokenPasswordField.password)
        clientConfiguration.maxNewTokens = maxNewTokensTextField.text.toInt()
        clientConfiguration.waitForModel = waitForModelCheckBox.isSelected
        service.verifyConfiguration(clientConfiguration, verifyLabel)
    }

    private fun Panel.tokenRow() {
        row {
            label(message("settings.llmClient.token"))
                .widthGroup("label")
            cell(tokenPasswordField)
                .bindText(getter = { "" }, setter = {
                    HuggingFaceClientService.getInstance().saveToken(clientConfiguration, it)
                })
                .emptyText(if (clientConfiguration.tokenIsStored) message("settings.llmClient.token.stored") else message("settings.huggingface.token.example"))
                .resizableColumn()
                .align(Align.FILL)
                // maxLineLength was eye-balled, but prevents the dialog from getting wider
                .comment(message("settings.huggingface.token.comment"), 50)
        }
    }

    private fun Panel.maxNewTokens() {
        row {
            label(message("settings.huggingface.maxNewTokens"))
                .widthGroup("label")
            cell(maxNewTokensTextField)
                .bindIntText(clientConfiguration::maxNewTokens)
                .resizableColumn()
                .align(Align.FILL)
        }
    }

    private fun Panel.waitForModel() {
        row {
            label(message("settings.huggingface.waitForModel"))
                .widthGroup("label")
            cell(waitForModelCheckBox)
                .bindSelected(clientConfiguration::waitForModel)
                .resizableColumn()
                .align(Align.FILL)

            contextHelp(message("settings.huggingface.waitModel.comment"))
                .align(AlignX.RIGHT)
        }
    }
}

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/huggingface/HuggingFaceClientService.kt

Lines changed: 50 additions & 0 deletions (new file)
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface

import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.getCredentialAttributes
import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.retrieveToken
import com.github.blarc.ai.commits.intellij.plugin.notifications.Notification
import com.github.blarc.ai.commits.intellij.plugin.notifications.sendNotification
import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientService
import com.intellij.ide.passwordSafe.PasswordSafe
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.util.text.nullize
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.huggingface.HuggingFaceChatModel
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.time.Duration

@Service(Service.Level.APP)
class HuggingFaceClientService(private val cs: CoroutineScope) : LLMClientService<HuggingFaceClientConfiguration>(cs) {

    companion object {
        @JvmStatic
        fun getInstance(): HuggingFaceClientService = service()
    }

    override suspend fun buildChatModel(client: HuggingFaceClientConfiguration): ChatLanguageModel {
        val token = client.token.nullize(true) ?: retrieveToken(client.id)?.toString(true)

        return HuggingFaceChatModel.builder()
            .accessToken(token)
            .modelId(client.modelId)
            .temperature(client.temperature.toDouble())
            .timeout(Duration.ofSeconds(client.timeout.toLong()))
            .maxNewTokens(client.maxNewTokens)
            .waitForModel(client.waitForModel)
            .build()
    }

    fun saveToken(client: HuggingFaceClientConfiguration, token: String) {
        cs.launch(Dispatchers.Default) {
            try {
                PasswordSafe.instance.setPassword(getCredentialAttributes(client.id), token)
                client.tokenIsStored = true
            } catch (e: Exception) {
                sendNotification(Notification.unableToSaveToken(e.message))
            }
        }
    }
}
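
saveToken writes the token to IntelliJ's PasswordSafe under credential attributes keyed by the client id, while buildChatModel reads it back through the plugin's retrieveToken helper, which is not part of this diff. A hypothetical sketch of that lookup side, assuming it simply asks PasswordSafe for the credentials stored by saveToken and returns the password as a OneTimeString (matching the ?.toString(true) call above):

import com.github.blarc.ai.commits.intellij.plugin.AICommitsUtils.getCredentialAttributes
import com.intellij.credentialStore.OneTimeString
import com.intellij.ide.passwordSafe.PasswordSafe

// Hypothetical sketch: the real retrieveToken lives in AICommitsUtils and may differ.
fun retrieveTokenSketch(client: HuggingFaceClientConfiguration): OneTimeString? =
    PasswordSafe.instance.get(getCredentialAttributes(client.id))?.password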

src/main/kotlin/com/github/blarc/ai/commits/intellij/plugin/settings/clients/huggingface/HuggingFaceClientSharedState.kt

Lines changed: 32 additions & 0 deletions (new file)
package com.github.blarc.ai.commits.intellij.plugin.settings.clients.huggingface

import com.github.blarc.ai.commits.intellij.plugin.settings.clients.LLMClientSharedState
import com.intellij.openapi.components.*
import com.intellij.util.xmlb.annotations.XCollection
import dev.langchain4j.model.huggingface.HuggingFaceModelName

@Service(Service.Level.APP)
@State(name = "HuggingFaceClientSharedState", storages = [Storage("AICommitsHuggingFace.xml")])
class HuggingFaceClientSharedState : PersistentStateComponent<HuggingFaceClientSharedState>, LLMClientSharedState {

    companion object {
        @JvmStatic
        fun getInstance(): HuggingFaceClientSharedState = service()
    }

    @XCollection(style = XCollection.Style.v2)
    override val hosts: MutableSet<String> = mutableSetOf()

    @XCollection(style = XCollection.Style.v2)
    override val modelIds: MutableSet<String> = mutableSetOf(
        HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT
    )

    override fun getState(): HuggingFaceClientSharedState = this

    override fun loadState(state: HuggingFaceClientSharedState) {
        // Add all model IDs from enum in case they are not stored in xml
        modelIds += state.modelIds
        hosts += state.hosts
    }
}

0 commit comments
