Skip to content

Commit 2213483

Browse files
digitpicSenreySong
authored andcommitted
refactor: Simplifies VectorStoreChatMemoryAdvisor code
Refactors VectorStoreChatMemoryAdvisor for improved readability and maintainability. Uses local variables to hold intermediate results, and ensures type declarations for clarity. Signed-off-by: Kyuwon Jeong <[email protected]>
1 parent 715aa8c commit 2213483

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.stream.Collectors;
2324

2425
import reactor.core.publisher.Flux;
2526
import reactor.core.publisher.Mono;
@@ -37,9 +38,11 @@
3738
import org.springframework.ai.chat.messages.AssistantMessage;
3839
import org.springframework.ai.chat.messages.Message;
3940
import org.springframework.ai.chat.messages.MessageType;
41+
import org.springframework.ai.chat.messages.SystemMessage;
4042
import org.springframework.ai.chat.messages.UserMessage;
4143
import org.springframework.ai.chat.prompt.PromptTemplate;
4244
import org.springframework.ai.document.Document;
45+
import org.springframework.ai.vectorstore.SearchRequest;
4346
import org.springframework.ai.vectorstore.VectorStore;
4447
import org.springframework.util.Assert;
4548

@@ -122,31 +125,23 @@ public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorC
122125
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
123126
int topK = getChatMemoryTopK(request.context());
124127
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
125-
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
126-
.query(query)
127-
.topK(topK)
128-
.filterExpression(filter)
129-
.build();
130-
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
131-
.similaritySearch(searchRequest);
128+
SearchRequest searchRequest = SearchRequest.builder().query(query).topK(topK).filterExpression(filter).build();
129+
List<Document> documents = this.vectorStore.similaritySearch(searchRequest);
132130

133131
String longTermMemory = documents == null ? ""
134-
: documents.stream()
135-
.map(org.springframework.ai.document.Document::getText)
136-
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
132+
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
137133

138-
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
134+
SystemMessage systemMessage = request.prompt().getSystemMessage();
139135
String augmentedSystemText = this.systemPromptTemplate
140-
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
136+
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
141137

142138
ChatClientRequest processedChatClientRequest = request.mutate()
143139
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
144140
.build();
145141

146-
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
147-
.getUserMessage();
142+
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
148143
if (userMessage != null) {
149-
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
144+
this.vectorStore.write(toDocuments(List.of(userMessage), conversationId));
150145
}
151146

152147
return processedChatClientRequest;
@@ -186,10 +181,11 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
186181
}
187182

188183
private List<Document> toDocuments(List<Message> messages, String conversationId) {
189-
List<Document> docs = messages.stream()
184+
return messages.stream()
190185
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
191186
.map(message -> {
192-
var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>());
187+
Map<String, Object> metadata = new HashMap<>(
188+
message.getMetadata() != null ? message.getMetadata() : new HashMap<>());
193189
metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
194190
metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
195191
if (message instanceof UserMessage userMessage) {
@@ -208,8 +204,6 @@ else if (message instanceof AssistantMessage assistantMessage) {
208204
throw new RuntimeException("Unknown message type: " + message.getMessageType());
209205
})
210206
.toList();
211-
212-
return docs;
213207
}
214208

215209
/**

0 commit comments

Comments
 (0)