| 
16 | 16 | 
 
  | 
17 | 17 | package org.springframework.ai.chat.client.advisor;  | 
18 | 18 | 
 
  | 
 | 19 | +import java.util.List;  | 
 | 20 | + | 
19 | 21 | import org.junit.jupiter.api.Test;  | 
20 | 22 | import reactor.core.scheduler.Schedulers;  | 
21 | 23 | 
 
  | 
 | 24 | +import org.springframework.ai.chat.client.ChatClientResponse;  | 
22 | 25 | import org.springframework.ai.chat.client.advisor.api.Advisor;  | 
 | 26 | +import org.springframework.ai.chat.client.advisor.api.AdvisorChain;  | 
23 | 27 | import org.springframework.ai.chat.memory.ChatMemory;  | 
24 | 28 | import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;  | 
25 | 29 | import org.springframework.ai.chat.memory.MessageWindowChatMemory;  | 
 | 30 | +import org.springframework.ai.chat.messages.AssistantMessage;  | 
 | 31 | +import org.springframework.ai.chat.messages.Message;  | 
 | 32 | +import org.springframework.ai.chat.model.ChatResponse;  | 
 | 33 | +import org.springframework.ai.chat.model.Generation;  | 
26 | 34 | import org.springframework.ai.chat.prompt.PromptTemplate;  | 
27 | 35 | 
 
  | 
28 | 36 | import static org.assertj.core.api.Assertions.assertThat;  | 
29 | 37 | import static org.assertj.core.api.Assertions.assertThatThrownBy;  | 
 | 38 | +import static org.mockito.Mockito.mock;  | 
 | 39 | +import static org.mockito.Mockito.when;  | 
30 | 40 | 
 
  | 
31 | 41 | /**  | 
32 | 42 |  * Unit tests for {@link PromptChatMemoryAdvisor}.  | 
33 | 43 |  *  | 
34 | 44 |  * @author Mark Pollack  | 
35 | 45 |  * @author Thomas Vitale  | 
 | 46 | + * @author Soby Chacko  | 
36 | 47 |  */  | 
37 | 48 | public class PromptChatMemoryAdvisorTests {  | 
38 | 49 | 
 
  | 
@@ -138,4 +149,120 @@ void testDefaultValues() {  | 
138 | 149 | 		assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);  | 
139 | 150 | 	}  | 
140 | 151 | 
 
  | 
 | 152 | +	@Test  | 
 | 153 | +	void testAfterMethodHandlesSingleGeneration() {  | 
 | 154 | +		ChatMemory chatMemory = MessageWindowChatMemory.builder()  | 
 | 155 | +			.chatMemoryRepository(new InMemoryChatMemoryRepository())  | 
 | 156 | +			.build();  | 
 | 157 | + | 
 | 158 | +		PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)  | 
 | 159 | +			.conversationId("test-conversation")  | 
 | 160 | +			.build();  | 
 | 161 | + | 
 | 162 | +		ChatClientResponse mockResponse = mock(ChatClientResponse.class);  | 
 | 163 | +		ChatResponse mockChatResponse = mock(ChatResponse.class);  | 
 | 164 | +		Generation mockGeneration = mock(Generation.class);  | 
 | 165 | +		AdvisorChain mockChain = mock(AdvisorChain.class);  | 
 | 166 | + | 
 | 167 | +		when(mockResponse.chatResponse()).thenReturn(mockChatResponse);  | 
 | 168 | +		when(mockChatResponse.getResults()).thenReturn(List.of(mockGeneration)); // Single  | 
 | 169 | +																					// result  | 
 | 170 | +		when(mockGeneration.getOutput()).thenReturn(new AssistantMessage("Single response"));  | 
 | 171 | + | 
 | 172 | +		ChatClientResponse result = advisor.after(mockResponse, mockChain);  | 
 | 173 | + | 
 | 174 | +		assertThat(result).isEqualTo(mockResponse); // Should return the same response  | 
 | 175 | + | 
 | 176 | +		// Verify single message stored in memory  | 
 | 177 | +		List<Message> messages = chatMemory.get("test-conversation");  | 
 | 178 | +		assertThat(messages).hasSize(1);  | 
 | 179 | +		assertThat(messages.get(0).getText()).isEqualTo("Single response");  | 
 | 180 | +	}  | 
 | 181 | + | 
 | 182 | +	@Test  | 
 | 183 | +	void testAfterMethodHandlesMultipleGenerations() {  | 
 | 184 | +		ChatMemory chatMemory = MessageWindowChatMemory.builder()  | 
 | 185 | +			.chatMemoryRepository(new InMemoryChatMemoryRepository())  | 
 | 186 | +			.build();  | 
 | 187 | + | 
 | 188 | +		PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)  | 
 | 189 | +			.conversationId("test-conversation")  | 
 | 190 | +			.build();  | 
 | 191 | + | 
 | 192 | +		ChatClientResponse mockResponse = mock(ChatClientResponse.class);  | 
 | 193 | +		ChatResponse mockChatResponse = mock(ChatResponse.class);  | 
 | 194 | +		Generation mockGen1 = mock(Generation.class);  | 
 | 195 | +		Generation mockGen2 = mock(Generation.class);  | 
 | 196 | +		Generation mockGen3 = mock(Generation.class);  | 
 | 197 | +		AdvisorChain mockChain = mock(AdvisorChain.class);  | 
 | 198 | + | 
 | 199 | +		when(mockResponse.chatResponse()).thenReturn(mockChatResponse);  | 
 | 200 | +		when(mockChatResponse.getResults()).thenReturn(List.of(mockGen1, mockGen2, mockGen3)); // Multiple  | 
 | 201 | +																								// results  | 
 | 202 | +		when(mockGen1.getOutput()).thenReturn(new AssistantMessage("Response 1"));  | 
 | 203 | +		when(mockGen2.getOutput()).thenReturn(new AssistantMessage("Response 2"));  | 
 | 204 | +		when(mockGen3.getOutput()).thenReturn(new AssistantMessage("Response 3"));  | 
 | 205 | + | 
 | 206 | +		ChatClientResponse result = advisor.after(mockResponse, mockChain);  | 
 | 207 | + | 
 | 208 | +		assertThat(result).isEqualTo(mockResponse); // Should return the same response  | 
 | 209 | + | 
 | 210 | +		// Verify all messages were stored in memory  | 
 | 211 | +		List<Message> messages = chatMemory.get("test-conversation");  | 
 | 212 | +		assertThat(messages).hasSize(3);  | 
 | 213 | +		assertThat(messages.get(0).getText()).isEqualTo("Response 1");  | 
 | 214 | +		assertThat(messages.get(1).getText()).isEqualTo("Response 2");  | 
 | 215 | +		assertThat(messages.get(2).getText()).isEqualTo("Response 3");  | 
 | 216 | +	}  | 
 | 217 | + | 
 | 218 | +	@Test  | 
 | 219 | +	void testAfterMethodHandlesEmptyResults() {  | 
 | 220 | +		ChatMemory chatMemory = MessageWindowChatMemory.builder()  | 
 | 221 | +			.chatMemoryRepository(new InMemoryChatMemoryRepository())  | 
 | 222 | +			.build();  | 
 | 223 | + | 
 | 224 | +		PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)  | 
 | 225 | +			.conversationId("test-conversation")  | 
 | 226 | +			.build();  | 
 | 227 | + | 
 | 228 | +		ChatClientResponse mockResponse = mock(ChatClientResponse.class);  | 
 | 229 | +		ChatResponse mockChatResponse = mock(ChatResponse.class);  | 
 | 230 | +		AdvisorChain mockChain = mock(AdvisorChain.class);  | 
 | 231 | + | 
 | 232 | +		when(mockResponse.chatResponse()).thenReturn(mockChatResponse);  | 
 | 233 | +		when(mockChatResponse.getResults()).thenReturn(List.of());  | 
 | 234 | + | 
 | 235 | +		ChatClientResponse result = advisor.after(mockResponse, mockChain);  | 
 | 236 | + | 
 | 237 | +		assertThat(result).isEqualTo(mockResponse);  | 
 | 238 | + | 
 | 239 | +		// Verify no messages were stored in memory  | 
 | 240 | +		List<Message> messages = chatMemory.get("test-conversation");  | 
 | 241 | +		assertThat(messages).isEmpty();  | 
 | 242 | +	}  | 
 | 243 | + | 
 | 244 | +	@Test  | 
 | 245 | +	void testAfterMethodHandlesNullChatResponse() {  | 
 | 246 | +		ChatMemory chatMemory = MessageWindowChatMemory.builder()  | 
 | 247 | +			.chatMemoryRepository(new InMemoryChatMemoryRepository())  | 
 | 248 | +			.build();  | 
 | 249 | + | 
 | 250 | +		PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)  | 
 | 251 | +			.conversationId("test-conversation")  | 
 | 252 | +			.build();  | 
 | 253 | + | 
 | 254 | +		ChatClientResponse mockResponse = mock(ChatClientResponse.class);  | 
 | 255 | +		AdvisorChain mockChain = mock(AdvisorChain.class);  | 
 | 256 | + | 
 | 257 | +		when(mockResponse.chatResponse()).thenReturn(null);  | 
 | 258 | + | 
 | 259 | +		ChatClientResponse result = advisor.after(mockResponse, mockChain);  | 
 | 260 | + | 
 | 261 | +		assertThat(result).isEqualTo(mockResponse);  | 
 | 262 | + | 
 | 263 | +		// Verify no messages were stored in memory  | 
 | 264 | +		List<Message> messages = chatMemory.get("test-conversation");  | 
 | 265 | +		assertThat(messages).isEmpty();  | 
 | 266 | +	}  | 
 | 267 | + | 
141 | 268 | }  | 
0 commit comments