Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@

import com.airhacks.afterburner.views.ViewLoader;
import com.google.common.annotations.VisibleForTesting;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;
import org.jabref.logic.ai.framework.messages.LlmMessage;
import org.jabref.logic.ai.framework.messages.UserMessage;
import org.controlsfx.control.PopOver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -264,11 +264,11 @@ private void onSendMessage(String userPrompt) {
updatePromptHistory();
setLoading(true);

BackgroundTask<AiMessage> task =
BackgroundTask<LlmMessage> task =
BackgroundTask
.wrap(() -> aiChatLogic.execute(userMessage))
.showToUser(true)
.onSuccess(aiMessage -> {
.onSuccess(llmMessage -> {
setLoading(false);
chatPrompt.requestPromptFocus();
})
Expand Down Expand Up @@ -299,7 +299,7 @@ private void addError(String error) {

private void updatePromptHistory() {
chatPrompt.getHistory().clear();
chatPrompt.getHistory().addAll(getReversedUserMessagesStream().map(UserMessage::singleText).toList());
chatPrompt.getHistory().addAll(getReversedUserMessagesStream().map(UserMessage::getText).toList());
}

private Stream<UserMessage> getReversedUserMessagesStream() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;

import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;

/**
* Main class for AI chatting. It checks if the AI features are enabled and if the embedding model is properly set up.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;

import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;

public class AiChatWindow extends BaseWindow {
private final AiService aiService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.jabref.gui.util.UiTaskExecutor;

import com.airhacks.afterburner.views.ViewLoader;
import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;

public class ChatHistoryComponent extends ScrollPane {
@FXML private VBox vBox;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import org.jabref.logic.l10n.Localization;

import com.airhacks.afterburner.views.ViewLoader;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;
import org.jabref.logic.ai.framework.messages.LlmMessage;
import org.jabref.logic.ai.framework.messages.UserMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -74,32 +74,31 @@ public void setOnDelete(Consumer<ChatMessageComponent> onDeleteCallback) {
}

private void loadChatMessage() {
switch (chatMessage.get()) {
case UserMessage userMessage -> {
setColor("-jr-ai-message-user", "-jr-ai-message-user-border");
setNodeOrientation(NodeOrientation.RIGHT_TO_LEFT);
wrapperHBox.setAlignment(Pos.TOP_RIGHT);
sourceLabel.setText(Localization.lang("User"));
markdownTextFlow.setMarkdown(userMessage.singleText());
}

case AiMessage aiMessage -> {
setColor("-jr-ai-message-ai", "-jr-ai-message-ai-border");
setNodeOrientation(NodeOrientation.LEFT_TO_RIGHT);
wrapperHBox.setAlignment(Pos.TOP_LEFT);
sourceLabel.setText(Localization.lang("AI"));
markdownTextFlow.setMarkdown(aiMessage.text());
}

case ErrorMessage errorMessage -> {
setColor("-jr-ai-message-error", "-jr-ai-message-error-border");
setNodeOrientation(NodeOrientation.LEFT_TO_RIGHT);
sourceLabel.setText(Localization.lang("Error"));
markdownTextFlow.setMarkdown(errorMessage.getText());
}

default ->
LOGGER.error("ChatMessageComponent supports only user, AI, or error messages, but other type was passed: {}", chatMessage.get().type().name());
ChatMessage message = chatMessage.get();

if (message instanceof UserMessage userMessage) {
setColor("-jr-ai-message-user", "-jr-ai-message-user-border");
setNodeOrientation(NodeOrientation.RIGHT_TO_LEFT);
wrapperHBox.setAlignment(Pos.TOP_RIGHT);
sourceLabel.setText(Localization.lang("User"));
markdownTextFlow.setMarkdown(userMessage.getText());
} else if (message instanceof LlmMessage llmMessage) {
setColor("-jr-ai-message-ai", "-jr-ai-message-ai-border");
setNodeOrientation(NodeOrientation.LEFT_TO_RIGHT);
wrapperHBox.setAlignment(Pos.TOP_LEFT);
sourceLabel.setText(Localization.lang("AI"));
markdownTextFlow.setMarkdown(llmMessage.getText());
} else if (message instanceof ErrorMessage errorMessage) {
setColor("-jr-ai-message-error", "-jr-ai-message-error-border");
setNodeOrientation(NodeOrientation.LEFT_TO_RIGHT);
sourceLabel.setText(Localization.lang("Error"));
markdownTextFlow.setMarkdown(errorMessage.getText());
} else {
setColor("-jr-ai-message-ai", "-jr-ai-message-ai-border");
setNodeOrientation(NodeOrientation.LEFT_TO_RIGHT);
wrapperHBox.setAlignment(Pos.TOP_LEFT);
sourceLabel.setText(Localization.lang("Unknown"));
markdownTextFlow.setMarkdown(message.getText());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import org.jabref.model.metadata.MetaData;

import com.tobiasdiez.easybind.EasyBind;
import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;
import org.jspecify.annotations.NonNull;

public class GroupTreeViewModel extends AbstractViewModel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.jabref.model.groups.GroupTreeNode;

import com.google.common.eventbus.Subscribe;
import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import java.nio.file.Path;
import java.util.List;

import dev.langchain4j.data.message.ChatMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;

public interface ChatHistoryStorage {
List<ChatMessage> loadMessagesForEntry(Path bibDatabasePath, String citationKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.NotificationService;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import org.jabref.logic.ai.framework.messages.ChatMessage;
import org.jabref.logic.ai.framework.messages.LlmMessage;
import org.jabref.logic.ai.framework.messages.UserMessage;
import kotlin.ranges.IntRange;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -33,34 +33,28 @@ public static ChatHistoryRecord fromLangchainMessage(ChatMessage chatMessage) {
}

private static String getContentFromLangchainMessage(ChatMessage chatMessage) {
String content;

switch (chatMessage) {
case AiMessage aiMessage ->
content = aiMessage.text();
case UserMessage userMessage ->
content = userMessage.singleText();
case ErrorMessage errorMessage ->
content = errorMessage.getText();
default -> {
LOGGER.warn("ChatHistoryRecord supports only AI, user. and error messages, but added message has other type: {}", chatMessage.type().name());
return "";
}
if (chatMessage instanceof LlmMessage llmMessage) {
return llmMessage.getText();
} else if (chatMessage instanceof UserMessage userMessage) {
return userMessage.getText();
} else if (chatMessage instanceof ErrorMessage errorMessage) {
return errorMessage.getText();
} else {
LOGGER.warn("ChatHistoryRecord supports only AI, user, and error messages, but added message has other type: {}", chatMessage.getClass().getSimpleName());
return "";
}

return content;
}

public ChatMessage toLangchainMessage() {
if (className.equals(AiMessage.class.getName())) {
return new AiMessage(content);
} else if (className.equals(UserMessage.class.getName())) {
if (className.equals(LlmMessage.class.getName()) || className.equals("dev.langchain4j.data.message.AiMessage")) {
return new LlmMessage(content);
} else if (className.equals(UserMessage.class.getName()) || className.equals("dev.langchain4j.data.message.UserMessage")) {
return new UserMessage(content);
} else if (className.equals(ErrorMessage.class.getName())) {
return new ErrorMessage(content);
} else {
LOGGER.warn("ChatHistoryRecord supports only AI and user messages, but retrieved message has other type: {}. Will treat as an AI message.", className);
return new AiMessage(content);
return new LlmMessage(content);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.embeddings;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when embedding computation fails.
*/
public class EmbeddingComputationException extends EmbeddingException {

@Override
public String getLocalizedMessage() {
return Localization.lang("Failed to compute embeddings.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.embeddings;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when embedding operations fail.
*/
public class EmbeddingException extends Exception {

@Override
public String getLocalizedMessage() {
return Localization.lang("An error occurred during embedding computation.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.jabref.logic.ai.framework.embeddings;

import java.util.List;

/**
* Interface for embedding model implementations.
*/
public interface EmbeddingModel {

/**
* Computes an embedding vector for the given text.
*
* @param text the text to embed
* @param type the type of embedding to generate
* @return the embedding vector as a list of floats
* @throws EmbeddingException if embedding computation fails
*/
List<Float> compute(String text, EmbeddingType type) throws EmbeddingException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.embeddings;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when the embedding model is not available.
*/
public class EmbeddingModelNotAvailableException extends EmbeddingException {

@Override
public String getLocalizedMessage() {
return Localization.lang("Embedding model is not available.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.jabref.logic.ai.framework.embeddings;

/**
* Types of embeddings that can be generated for different use cases.
*/
public enum EmbeddingType {
/**
* Embedding optimized for question-like content.
*/
QUESTION,

/**
* Embedding optimized for answer-like content.
*/
ANSWER
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.llms;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when LLM authentication fails.
*/
public class LlmAuthenticationException extends LlmInferenceException {

@Override
public String getLocalizedMessage() {
return Localization.lang("LLM authentication failed.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.llms;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when connection to LLM service fails.
*/
public class LlmConnectionException extends LlmInferenceException {

@Override
public String getLocalizedMessage() {
return Localization.lang("Failed to connect to LLM service.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jabref.logic.ai.framework.llms;

import org.jabref.logic.l10n.Localization;

/**
* Exception thrown when LLM inference fails.
*/
public class LlmInferenceException extends Exception {

@Override
public String getLocalizedMessage() {
return Localization.lang("An error occurred during LLM inference.");
}
}
Loading