diff --git a/CHANGELOG.md b/CHANGELOG.md index 861471f2d69..7ce2a5a2944 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Note that this project **does not** adhere to [Semantic Versioning](https://semv ### Added +- We added a "Regenerate" button for the AI chat allowing the user to make the LLM reformulate its answer to the previous prompt [#12191](https://github.com/JabRef/jabref/issues/12191) + ### Changed ### Fixed diff --git a/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/AiChatComponent.java b/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/AiChatComponent.java index 17abb13534c..5394da40fc5 100644 --- a/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/AiChatComponent.java +++ b/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/AiChatComponent.java @@ -43,6 +43,7 @@ import com.google.common.annotations.VisibleForTesting; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; import dev.langchain4j.data.message.UserMessage; import org.controlsfx.control.PopOver; import org.slf4j.Logger; @@ -191,6 +192,20 @@ private void initializeChatPrompt() { onSendMessage(userMessage); }); + chatPrompt.setRegenerateCallback(() -> { + String lastUserPrompt = ""; + if (!aiChatLogic.getChatHistory().isEmpty()) { + lastUserPrompt = getLastUserMessage().singleText(); + deleteLastMessage(); + deleteLastMessage(); + } + + chatPrompt.switchToNormalState(); + if (!lastUserPrompt.isEmpty()) { + onSendMessage(lastUserPrompt); + } + }); + chatPrompt.requestPromptFocus(); updatePromptHistory(); @@ -334,4 +349,19 @@ private void deleteLastMessage() { aiChatLogic.getChatHistory().remove(index); } } + + private UserMessage getLastUserMessage() { + if (!aiChatLogic.getChatHistory().isEmpty() && aiChatLogic.getChatHistory().size() >= 2) { + int userMessageIndex = aiChatLogic.getChatHistory().size() - 2; + ChatMessage chat = aiChatLogic.getChatHistory().get(userMessageIndex); + + if (chat.type() == ChatMessageType.USER) { + return (UserMessage) chat; + } else { + return new UserMessage(""); + } + } else { + return new UserMessage(""); + } + } } diff --git a/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.java b/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.java index 890e1761861..30bcd03a1fc 100644 --- a/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.java +++ b/jabgui/src/main/java/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.java @@ -30,6 +30,7 @@ public class ChatPromptComponent extends HBox { private final ObjectProperty> sendCallback = new SimpleObjectProperty<>(); private final ObjectProperty> retryCallback = new SimpleObjectProperty<>(); private final ObjectProperty cancelCallback = new SimpleObjectProperty<>(); + private final ObjectProperty regenerateCallback = new SimpleObjectProperty<>(); private final ListProperty history = new SimpleListProperty<>(FXCollections.observableArrayList()); @@ -44,6 +45,7 @@ public class ChatPromptComponent extends HBox { @FXML private ExpandingTextArea userPromptTextArea; @FXML private Button submitButton; + @FXML private Button regenerateButton; public ChatPromptComponent() { ViewLoader.view(this) @@ -68,6 +70,10 @@ public void setCancelCallback(Runnable cancelCallback) { this.cancelCallback.set(cancelCallback); } + public void setRegenerateCallback(Runnable regenerateCallback) { + this.regenerateCallback.set(regenerateCallback); + } + public ListProperty getHistory() { return history; } @@ -174,6 +180,7 @@ public void switchToNormalState() { this.getChildren().clear(); this.getChildren().add(userPromptTextArea); this.getChildren().add(submitButton); + this.getChildren().add(regenerateButton); requestPromptFocus(); } @@ -191,4 +198,11 @@ private void onSendMessage() { sendCallback.get().accept(userPrompt); } } + + @FXML + private void onRegenerateMessage() { + if (regenerateCallback.get() != null) { + regenerateCallback.get().run(); + } + } } diff --git a/jabgui/src/main/resources/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.fxml b/jabgui/src/main/resources/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.fxml index 04c53eaab88..fdf05978f8d 100644 --- a/jabgui/src/main/resources/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.fxml +++ b/jabgui/src/main/resources/org/jabref/gui/ai/components/aichat/chatprompt/ChatPromptComponent.fxml @@ -18,4 +18,8 @@ mnemonicParsing="false" onAction="#onSendMessage" text="%Submit"/> +