Skip to content

Commit 897bd36

Browse files
committed
feat: add incremental refresh support to ChatMemoryRepository
1 parent 5bc46dd commit 897bd36

File tree

3 files changed

+73
-26
lines changed

3 files changed

+73
-26
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,15 @@ public interface ChatMemoryRepository {
4040

4141
void deleteByConversationId(String conversationId);
4242

43+
/**
44+
* Atomically removes the messages in {@code deletes} and adds the messages in {@code adds}
45+
* for the given conversation ID. This provides a more efficient way to update
46+
* the memory than reading the entire history and overwriting it.
47+
*
48+
* @param conversationId The ID of the conversation to update.
49+
* @param deletes A list of messages to be removed from the memory.
50+
* @param adds A list of new messages to be added to the memory.
51+
*/
52+
void refresh(String conversationId, List<Message> deletes, List<Message> adds);
53+
4354
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,17 @@ public void deleteByConversationId(String conversationId) {
6060
this.chatMemoryStore.remove(conversationId);
6161
}
6262

63+
@Override
64+
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
65+
this.chatMemoryStore.compute(conversationId, (key, currentMessages) -> {
66+
if (currentMessages == null) {
67+
return new ArrayList<>(adds);
68+
}
69+
List<Message> updatedMessages = new ArrayList<>(currentMessages);
70+
updatedMessages.removeAll(deletes);
71+
updatedMessages.addAll(adds);
72+
return updatedMessages;
73+
});
74+
}
75+
6376
}

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.HashSet;
21+
import java.util.LinkedHashSet;
2122
import java.util.List;
2223
import java.util.Set;
2324

@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
6162
Assert.noNullElements(messages, "messages cannot contain null elements");
6263

6364
List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
64-
List<Message> processedMessages = process(memoryMessages, messages);
65-
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
65+
MessageChanges changes = process(memoryMessages, messages);
66+
if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) {
67+
this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd);
68+
}
6669
}
6770

6871
@Override
@@ -77,38 +80,58 @@ public void clear(String conversationId) {
7780
this.chatMemoryRepository.deleteByConversationId(conversationId);
7881
}
7982

80-
private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
81-
List<Message> processedMessages = new ArrayList<>();
83+
private MessageChanges process(List<Message> memoryMessages, List<Message> newMessages) {
84+
Set<Message> originalMessageSet = new LinkedHashSet<>(memoryMessages);
85+
List<Message> uniqueNewMessages = newMessages.stream()
86+
.filter(msg -> !originalMessageSet.contains(msg))
87+
.toList();
88+
boolean hasNewSystemMessage = uniqueNewMessages.stream().anyMatch(SystemMessage.class::isInstance);
89+
90+
List<Message> finalMessages = new ArrayList<>();
91+
if (hasNewSystemMessage) {
92+
memoryMessages.stream().filter(msg -> !(msg instanceof SystemMessage)).forEach(finalMessages::add);
93+
finalMessages.addAll(uniqueNewMessages);
94+
}
95+
else {
96+
finalMessages.addAll(memoryMessages);
97+
finalMessages.addAll(uniqueNewMessages);
98+
}
8299

83-
Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
84-
boolean hasNewSystemMessage = newMessages.stream()
85-
.filter(SystemMessage.class::isInstance)
86-
.anyMatch(message -> !memoryMessagesSet.contains(message));
100+
if (finalMessages.size() > this.maxMessages) {
101+
List<Message> trimmedMessages = new ArrayList<>();
102+
int messagesToRemove = finalMessages.size() - this.maxMessages;
103+
int removed = 0;
104+
for (Message message : finalMessages) {
105+
if (message instanceof SystemMessage || removed >= messagesToRemove) {
106+
trimmedMessages.add(message);
107+
}
108+
else {
109+
removed++;
110+
}
111+
}
112+
finalMessages = trimmedMessages;
113+
}
87114

88-
memoryMessages.stream()
89-
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
90-
.forEach(processedMessages::add);
115+
Set<Message> finalMessageSet = new LinkedHashSet<>(finalMessages);
91116

92-
processedMessages.addAll(newMessages);
117+
List<Message> toDelete = originalMessageSet.stream().filter(m -> !finalMessageSet.contains(m)).toList();
93118

94-
if (processedMessages.size() <= this.maxMessages) {
95-
return processedMessages;
96-
}
119+
List<Message> toAdd = finalMessageSet.stream().filter(m -> !originalMessageSet.contains(m)).toList();
97120

98-
int messagesToRemove = processedMessages.size() - this.maxMessages;
121+
return new MessageChanges(toDelete, toAdd);
122+
}
99123

100-
List<Message> trimmedMessages = new ArrayList<>();
101-
int removed = 0;
102-
for (Message message : processedMessages) {
103-
if (message instanceof SystemMessage || removed >= messagesToRemove) {
104-
trimmedMessages.add(message);
105-
}
106-
else {
107-
removed++;
108-
}
124+
private static class MessageChanges {
125+
126+
final List<Message> toDelete;
127+
128+
final List<Message> toAdd;
129+
130+
MessageChanges(List<Message> toDelete, List<Message> toAdd) {
131+
this.toDelete = toDelete;
132+
this.toAdd = toAdd;
109133
}
110134

111-
return trimmedMessages;
112135
}
113136

114137
public static Builder builder() {

0 commit comments

Comments
 (0)