1818
1919import java .util .ArrayList ;
2020import java .util .HashSet ;
21+ import java .util .LinkedHashSet ;
2122import java .util .List ;
2223import java .util .Set ;
2324
@@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
6162 Assert .noNullElements (messages , "messages cannot contain null elements" );
6263
6364 List <Message > memoryMessages = this .chatMemoryRepository .findByConversationId (conversationId );
64- List <Message > processedMessages = process (memoryMessages , messages );
65- this .chatMemoryRepository .saveAll (conversationId , processedMessages );
65+ MessageChanges changes = process (memoryMessages , messages );
66+ if (!changes .toDelete .isEmpty () || !changes .toAdd .isEmpty ()) {
67+ this .chatMemoryRepository .refresh (conversationId , changes .toDelete , changes .toAdd );
68+ }
6669 }
6770
6871 @ Override
@@ -77,38 +80,58 @@ public void clear(String conversationId) {
7780 this .chatMemoryRepository .deleteByConversationId (conversationId );
7881 }
7982
80- private List <Message > process (List <Message > memoryMessages , List <Message > newMessages ) {
81- List <Message > processedMessages = new ArrayList <>();
83+ private MessageChanges process (List <Message > memoryMessages , List <Message > newMessages ) {
84+ Set <Message > originalMessageSet = new LinkedHashSet <>(memoryMessages );
85+ List <Message > uniqueNewMessages = newMessages .stream ()
86+ .filter (msg -> !originalMessageSet .contains (msg ))
87+ .toList ();
88+ boolean hasNewSystemMessage = uniqueNewMessages .stream ().anyMatch (SystemMessage .class ::isInstance );
89+
90+ List <Message > finalMessages = new ArrayList <>();
91+ if (hasNewSystemMessage ) {
92+ memoryMessages .stream ().filter (msg -> !(msg instanceof SystemMessage )).forEach (finalMessages ::add );
93+ finalMessages .addAll (uniqueNewMessages );
94+ }
95+ else {
96+ finalMessages .addAll (memoryMessages );
97+ finalMessages .addAll (uniqueNewMessages );
98+ }
8299
83- Set <Message > memoryMessagesSet = new HashSet <>(memoryMessages );
84- boolean hasNewSystemMessage = newMessages .stream ()
85- .filter (SystemMessage .class ::isInstance )
86- .anyMatch (message -> !memoryMessagesSet .contains (message ));
100+ if (finalMessages .size () > this .maxMessages ) {
101+ List <Message > trimmedMessages = new ArrayList <>();
102+ int messagesToRemove = finalMessages .size () - this .maxMessages ;
103+ int removed = 0 ;
104+ for (Message message : finalMessages ) {
105+ if (message instanceof SystemMessage || removed >= messagesToRemove ) {
106+ trimmedMessages .add (message );
107+ }
108+ else {
109+ removed ++;
110+ }
111+ }
112+ finalMessages = trimmedMessages ;
113+ }
87114
88- memoryMessages .stream ()
89- .filter (message -> !(hasNewSystemMessage && message instanceof SystemMessage ))
90- .forEach (processedMessages ::add );
115+ Set <Message > finalMessageSet = new LinkedHashSet <>(finalMessages );
91116
92- processedMessages . addAll ( newMessages );
117+ List < Message > toDelete = originalMessageSet . stream (). filter ( m -> ! finalMessageSet . contains ( m )). toList ( );
93118
94- if (processedMessages .size () <= this .maxMessages ) {
95- return processedMessages ;
96- }
119+ List <Message > toAdd = finalMessageSet .stream ().filter (m -> !originalMessageSet .contains (m )).toList ();
97120
98- int messagesToRemove = processedMessages .size () - this .maxMessages ;
121+ return new MessageChanges (toDelete , toAdd );
122+ }
99123
100- List < Message > trimmedMessages = new ArrayList <>();
101- int removed = 0 ;
102- for ( Message message : processedMessages ) {
103- if ( message instanceof SystemMessage || removed >= messagesToRemove ) {
104- trimmedMessages . add ( message ) ;
105- }
106- else {
107- removed ++ ;
108- }
124+ private static class MessageChanges {
125+
126+ final List < Message > toDelete ;
127+
128+ final List < Message > toAdd ;
129+
130+ MessageChanges ( List < Message > toDelete , List < Message > toAdd ) {
131+ this . toDelete = toDelete ;
132+ this . toAdd = toAdd ;
109133 }
110134
111- return trimmedMessages ;
112135 }
113136
114137 public static Builder builder () {
0 commit comments