2020import java .util .HashMap ;
2121import java .util .List ;
2222import java .util .Map ;
23- import java .util .stream .Collectors ;
2423
25- import reactor .core .publisher .Flux ;
24+ import org .slf4j .Logger ;
25+ import org .slf4j .LoggerFactory ;
26+ import reactor .core .scheduler .Scheduler ;
27+ import reactor .core .scheduler .Schedulers ;
2628
27- import org .springframework .ai .chat .client .advisor .AbstractChatMemoryAdvisor ;
2829import org .springframework .ai .chat .client .ChatClientRequest ;
2930import org .springframework .ai .chat .client .ChatClientResponse ;
30- import org .springframework .ai .chat .client .advisor .api .CallAdvisorChain ;
31- import org .springframework .ai .chat .client .advisor .api .StreamAdvisorChain ;
31+ import org .springframework .ai .chat .client .advisor .api .Advisor ;
32+ import org .springframework .ai .chat .client .advisor .api .AdvisorChain ;
33+ import org .springframework .ai .chat .client .advisor .api .BaseAdvisor ;
34+ import org .springframework .ai .chat .client .advisor .api .BaseChatMemoryAdvisor ;
35+ import org .springframework .ai .chat .memory .ChatMemory ;
3236import org .springframework .ai .chat .messages .AssistantMessage ;
3337import org .springframework .ai .chat .messages .Message ;
3438import org .springframework .ai .chat .messages .MessageType ;
35- import org .springframework .ai .chat .messages .SystemMessage ;
3639import org .springframework .ai .chat .messages .UserMessage ;
37- import org .springframework .ai .chat .model .MessageAggregator ;
3840import org .springframework .ai .chat .prompt .PromptTemplate ;
3941import org .springframework .ai .document .Document ;
40- import org .springframework .ai .vectorstore .SearchRequest ;
4142import org .springframework .ai .vectorstore .VectorStore ;
4243
4344/**
4849 * @author Christian Tzolov
4950 * @author Thomas Vitale
5051 * @author Oganes Bozoyan
52+ * @author Mark Pollack
5153 * @since 1.0.0
5254 */
53- public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor <VectorStore > {
55+ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
56+
57+ public static final String CHAT_MEMORY_RETRIEVE_SIZE_KEY = "chat_memory_response_size" ;
5458
5559 private static final String DOCUMENT_METADATA_CONVERSATION_ID = "conversationId" ;
5660
5761 private static final String DOCUMENT_METADATA_MESSAGE_TYPE = "messageType" ;
5862
63+ /**
64+ * The default chat memory retrieve size to use when no retrieve size is provided.
65+ */
66+ public static final int DEFAULT_TOP_K = 20 ;
67+
5968 private static final PromptTemplate DEFAULT_SYSTEM_PROMPT_TEMPLATE = new PromptTemplate ("""
6069 {instructions}
6170
@@ -69,71 +78,84 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6978
7079 private final PromptTemplate systemPromptTemplate ;
7180
72- private VectorStoreChatMemoryAdvisor (VectorStore vectorStore , String defaultConversationId ,
73- int chatHistoryWindowSize , boolean protectFromBlocking , PromptTemplate systemPromptTemplate , int order ) {
74- super (vectorStore , defaultConversationId , chatHistoryWindowSize , protectFromBlocking , order );
81+ protected final int defaultChatMemoryRetrieveSize ;
82+
83+ private final String defaultConversationId ;
84+
85+ private final int order ;
86+
87+ private final Scheduler scheduler ;
88+
89+ private VectorStore vectorStore ;
90+
91+ public VectorStoreChatMemoryAdvisor (PromptTemplate systemPromptTemplate , int defaultChatMemoryRetrieveSize ,
92+ String defaultConversationId , int order , Scheduler scheduler , VectorStore vectorStore ) {
7593 this .systemPromptTemplate = systemPromptTemplate ;
94+ this .defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize ;
95+ this .defaultConversationId = defaultConversationId ;
96+ this .order = order ;
97+ this .scheduler = scheduler ;
98+ this .vectorStore = vectorStore ;
7699 }
77100
78101 public static Builder builder (VectorStore chatMemory ) {
79102 return new Builder (chatMemory );
80103 }
81104
82105 @ Override
83- public ChatClientResponse adviseCall (ChatClientRequest chatClientRequest , CallAdvisorChain callAdvisorChain ) {
84- chatClientRequest = this .before (chatClientRequest );
85-
86- ChatClientResponse chatClientResponse = callAdvisorChain .nextCall (chatClientRequest );
87-
88- this .after (chatClientResponse );
89-
90- return chatClientResponse ;
106+ public int getOrder () {
107+ return order ;
91108 }
92109
93110 @ Override
94- public Flux <ChatClientResponse > adviseStream (ChatClientRequest chatClientRequest ,
95- StreamAdvisorChain streamAdvisorChain ) {
96- Flux <ChatClientResponse > chatClientResponses = this .doNextWithProtectFromBlockingBefore (chatClientRequest ,
97- streamAdvisorChain , this ::before );
98-
99- return new MessageAggregator ().aggregateChatClientResponse (chatClientResponses , this ::after );
111+ public Scheduler getScheduler () {
112+ return this .scheduler ;
100113 }
101114
102- private ChatClientRequest before (ChatClientRequest chatClientRequest ) {
103- String conversationId = this .doGetConversationId (chatClientRequest .context ());
104- int chatMemoryRetrieveSize = this .doGetChatMemoryRetrieveSize (chatClientRequest .context ());
105-
106- // 1. Retrieve the chat memory for the current conversation.
107- var searchRequest = SearchRequest .builder ()
108- .query (chatClientRequest .prompt ().getUserMessage ().getText ())
109- .topK (chatMemoryRetrieveSize )
110- .filterExpression (DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" )
115+ @ Override
116+ public ChatClientRequest before (ChatClientRequest request , AdvisorChain advisorChain ) {
117+ String conversationId = getConversationId (request .context ());
118+ String query = request .prompt ().getUserMessage () != null ? request .prompt ().getUserMessage ().getText () : "" ;
119+ int topK = getChatMemoryTopK (request .context ());
120+ String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'" ;
121+ var searchRequest = org .springframework .ai .vectorstore .SearchRequest .builder ()
122+ .query (query )
123+ .topK (topK )
124+ .filterExpression (filter )
111125 .build ();
126+ java .util .List <org .springframework .ai .document .Document > documents = this .vectorStore
127+ .similaritySearch (searchRequest );
112128
113- List <Document > documents = this .getChatMemoryStore ().similaritySearch (searchRequest );
114-
115- // 2. Processed memory messages as a string.
116129 String longTermMemory = documents == null ? ""
117- : documents .stream ().map (Document ::getText ).collect (Collectors .joining (System .lineSeparator ()));
130+ : documents .stream ()
131+ .map (org .springframework .ai .document .Document ::getText )
132+ .collect (java .util .stream .Collectors .joining (System .lineSeparator ()));
118133
119- // 2. Augment the system message.
120- SystemMessage systemMessage = chatClientRequest .prompt ().getSystemMessage ();
134+ org .springframework .ai .chat .messages .SystemMessage systemMessage = request .prompt ().getSystemMessage ();
121135 String augmentedSystemText = this .systemPromptTemplate
122- .render (Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
136+ .render (java . util . Map .of ("instructions" , systemMessage .getText (), "long_term_memory" , longTermMemory ));
123137
124- // 3. Create a new request with the augmented system message.
125- ChatClientRequest processedChatClientRequest = chatClientRequest .mutate ()
126- .prompt (chatClientRequest .prompt ().augmentSystemMessage (augmentedSystemText ))
138+ ChatClientRequest processedChatClientRequest = request .mutate ()
139+ .prompt (request .prompt ().augmentSystemMessage (augmentedSystemText ))
127140 .build ();
128141
129- // 4. Add the new user message to the conversation memory.
130- UserMessage userMessage = processedChatClientRequest .prompt ().getUserMessage ();
131- this .getChatMemoryStore ().write (toDocuments (List .of (userMessage ), conversationId ));
142+ org .springframework .ai .chat .messages .UserMessage userMessage = processedChatClientRequest .prompt ()
143+ .getUserMessage ();
144+ if (userMessage != null ) {
145+ this .vectorStore .write (toDocuments (java .util .List .of (userMessage ), conversationId ));
146+ }
132147
133148 return processedChatClientRequest ;
134149 }
135150
136- private void after (ChatClientResponse chatClientResponse ) {
151+ private int getChatMemoryTopK (Map <String , Object > context ) {
152+ return context .containsKey (CHAT_MEMORY_RETRIEVE_SIZE_KEY )
153+ ? Integer .parseInt (context .get (CHAT_MEMORY_RETRIEVE_SIZE_KEY ).toString ())
154+ : this .defaultChatMemoryRetrieveSize ;
155+ }
156+
157+ @ Override
158+ public ChatClientResponse after (ChatClientResponse chatClientResponse , AdvisorChain advisorChain ) {
137159 List <Message > assistantMessages = new ArrayList <>();
138160 if (chatClientResponse .chatResponse () != null ) {
139161 assistantMessages = chatClientResponse .chatResponse ()
@@ -142,8 +164,8 @@ private void after(ChatClientResponse chatClientResponse) {
142164 .map (g -> (Message ) g .getOutput ())
143165 .toList ();
144166 }
145- this .getChatMemoryStore ()
146- . write ( toDocuments ( assistantMessages , this . doGetConversationId ( chatClientResponse . context ()))) ;
167+ this .vectorStore . write ( toDocuments ( assistantMessages , this . getConversationId ( chatClientResponse . context ())));
168+ return chatClientResponse ;
147169 }
148170
149171 private List <Document > toDocuments (List <Message > messages , String conversationId ) {
@@ -173,28 +195,93 @@ else if (message instanceof AssistantMessage assistantMessage) {
173195 return docs ;
174196 }
175197
176- public static class Builder extends AbstractChatMemoryAdvisor .AbstractBuilder <VectorStore > {
198+ /**
199+ * Builder for VectorStoreChatMemoryAdvisor.
200+ */
201+ public static class Builder {
177202
178203 private PromptTemplate systemPromptTemplate = DEFAULT_SYSTEM_PROMPT_TEMPLATE ;
179204
180- protected Builder (VectorStore chatMemory ) {
181- super (chatMemory );
182- }
205+ private Integer topK = DEFAULT_TOP_K ;
183206
184- public Builder systemTextAdvise (String systemTextAdvise ) {
185- this .systemPromptTemplate = new PromptTemplate (systemTextAdvise );
186- return this ;
207+ private String conversationId = ChatMemory .DEFAULT_CONVERSATION_ID ;
208+
209+ private Scheduler scheduler ;
210+
211+ private int order = Advisor .DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER ;
212+
213+ private VectorStore vectorStore ;
214+
215+ /**
216+ * Creates a new builder instance.
217+ * @param vectorStore the vector store to use
218+ */
219+ protected Builder (VectorStore vectorStore ) {
220+ this .vectorStore = vectorStore ;
187221 }
188222
223+ /**
224+ * Set the system prompt template.
225+ * @param systemPromptTemplate the system prompt template
226+ * @return this builder
227+ */
189228 public Builder systemPromptTemplate (PromptTemplate systemPromptTemplate ) {
190229 this .systemPromptTemplate = systemPromptTemplate ;
191230 return this ;
192231 }
193232
194- @ Override
233+ /**
234+ * Set the chat memory retrieve size.
235+ * @param topK the chat memory retrieve size
236+ * @return this builder
237+ */
238+ public Builder topK (int topK ) {
239+ this .topK = topK ;
240+ return this ;
241+ }
242+
243+ /**
244+ * Set the conversation id.
245+ * @param conversationId the conversation id
246+ * @return the builder
247+ */
248+ public Builder conversationId (String conversationId ) {
249+ this .conversationId = conversationId ;
250+ return this ;
251+ }
252+
253+ /**
254+ * Set whether to protect from blocking.
255+ * @param protectFromBlocking whether to protect from blocking
256+ * @return the builder
257+ */
258+ public Builder protectFromBlocking (boolean protectFromBlocking ) {
259+ this .scheduler = protectFromBlocking ? BaseAdvisor .DEFAULT_SCHEDULER : Schedulers .immediate ();
260+ return this ;
261+ }
262+
263+ public Builder scheduler (Scheduler scheduler ) {
264+ this .scheduler = scheduler ;
265+ return this ;
266+ }
267+
268+ /**
269+ * Set the order.
270+ * @param order the order
271+ * @return the builder
272+ */
273+ public Builder order (int order ) {
274+ this .order = order ;
275+ return this ;
276+ }
277+
278+ /**
279+ * Build the advisor.
280+ * @return the advisor
281+ */
195282 public VectorStoreChatMemoryAdvisor build () {
196- return new VectorStoreChatMemoryAdvisor (this .chatMemory , this .conversationId , this .chatMemoryRetrieveSize ,
197- this .protectFromBlocking , this .systemPromptTemplate , this .order );
283+ return new VectorStoreChatMemoryAdvisor (this .systemPromptTemplate , this .topK , this .conversationId ,
284+ this .order , this .scheduler , this .vectorStore );
198285 }
199286
200287 }
0 commit comments