|
21 | 21 | import java.util.List; |
22 | 22 | import java.util.Map; |
23 | 23 |
|
| 24 | +import reactor.core.publisher.Flux; |
| 25 | +import reactor.core.publisher.Mono; |
24 | 26 | import reactor.core.scheduler.Scheduler; |
25 | 27 |
|
| 28 | +import org.springframework.ai.chat.client.ChatClientMessageAggregator; |
26 | 29 | import org.springframework.ai.chat.client.ChatClientRequest; |
27 | 30 | import org.springframework.ai.chat.client.ChatClientResponse; |
28 | 31 | import org.springframework.ai.chat.client.advisor.api.Advisor; |
29 | 32 | import org.springframework.ai.chat.client.advisor.api.AdvisorChain; |
30 | 33 | import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; |
31 | 34 | import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor; |
| 35 | +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; |
32 | 36 | import org.springframework.ai.chat.memory.ChatMemory; |
33 | 37 | import org.springframework.ai.chat.messages.AssistantMessage; |
34 | 38 | import org.springframework.ai.chat.messages.Message; |
@@ -167,6 +171,20 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh |
167 | 171 | return chatClientResponse; |
168 | 172 | } |
169 | 173 |
|
| 174 | + @Override |
| 175 | + public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, |
| 176 | + StreamAdvisorChain streamAdvisorChain) { |
| 177 | + // Get the scheduler from BaseAdvisor |
| 178 | + Scheduler scheduler = this.getScheduler(); |
| 179 | + // Process the request with the before method |
| 180 | + return Mono.just(chatClientRequest) |
| 181 | + .publishOn(scheduler) |
| 182 | + .map(request -> this.before(request, streamAdvisorChain)) |
| 183 | + .flatMapMany(streamAdvisorChain::nextStream) |
| 184 | + .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux, |
| 185 | + response -> this.after(response, streamAdvisorChain))); |
| 186 | + } |
| 187 | + |
170 | 188 | private List<Document> toDocuments(List<Message> messages, String conversationId) { |
171 | 189 | List<Document> docs = messages.stream() |
172 | 190 | .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) |
|
0 commit comments