Skip to content

Commit de7bf3c

Browse files
feat: [Orchestration] Chat Memory (#367)
* feat: [Orchestration] Chat Memory * release notes * Absolute link * updated docs
1 parent 8fefe94 commit de7bf3c

File tree

9 files changed

+154
-1
lines changed

9 files changed

+154
-1
lines changed

docs/guides/SPRING_AI_INTEGRATION.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- [Orchestration Masking](#orchestration-masking)
88
- [Stream chat completion](#stream-chat-completion)
99
- [Tool Calling](#tool-calling)
10+
- [Chat Memory](#chat-memory)
1011

1112
## Introduction
1213

@@ -137,3 +138,26 @@ ChatResponse response = client.call(prompt);
137138

138139
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).
139140

141+
## Chat Memory
142+
143+
Create a Spring AI `ChatClient` from our `OrchestrationChatModel` and add a chat memory advisor like so:
144+
145+
```java
146+
ChatModel client = new OrchestrationChatModel();
147+
OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);
148+
OrchestrationChatOptions opts = new OrchestrationChatOptions(config);
149+
150+
val memory = new InMemoryChatMemory();
151+
val advisor = new MessageChatMemoryAdvisor(memory);
152+
val cl = ChatClient.builder(client).defaultAdvisors(advisor).build();
153+
154+
Prompt prompt1 = new Prompt("What is the capital of France?", defaultOptions);
155+
String content1 = cl.prompt(prompt1).call().content();
156+
// content1 is "Paris"
157+
158+
Prompt prompt2 = new Prompt("And what is the typical food there?", defaultOptions);
159+
String content2 = cl.prompt(prompt2).call().content();
160+
// chat memory will remember that the user is inquiring about France.
161+
```
162+
163+
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java).

docs/release-notes/release_notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15+
- [Orchestration] [Add Spring AI Chat Memory support](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/SPRING_AI_INTEGRATION.md#chat-memory)
1516
- [Orchestration] [Prompt templates can be consumed from registry.](https://github.com/SAP/ai-sdk-java/tree/main/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md#Chat-completion-with-Templates)
1617

1718
### 📈 Improvements

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ private static com.sap.ai.sdk.orchestration.Message[] toOrchestrationMessages(
129129
case ASSISTANT:
130130
val springToolCalls =
131131
((org.springframework.ai.chat.messages.AssistantMessage) msg).getToolCalls();
132-
if (springToolCalls != null) {
132+
if (springToolCalls != null && !springToolCalls.isEmpty()) {
133133
final List<ResponseMessageToolCall> sdkToolCalls =
134134
springToolCalls.stream()
135135
.map(OrchestrationChatModel::toOrchestrationToolCall)

orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
import org.junit.jupiter.api.BeforeEach;
4040
import org.junit.jupiter.api.Test;
4141
import org.mockito.Mockito;
42+
import org.springframework.ai.chat.client.ChatClient;
43+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
44+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
4245
import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
4346
import org.springframework.ai.chat.model.ChatResponse;
4447
import org.springframework.ai.chat.prompt.Prompt;
@@ -213,4 +216,41 @@ void testToolCallsWithExecution() throws IOException {
213216
}
214217
}
215218
}
219+
220+
@Test
221+
void testChatMemory() throws IOException {
222+
stubFor(
223+
post(urlPathEqualTo("/completion"))
224+
.inScenario("Chat Memory")
225+
.whenScenarioStateIs(STARTED)
226+
.willReturn(
227+
aResponse()
228+
.withBodyFile("templatingResponse.json") // The response is not important
229+
.withHeader("Content-Type", "application/json"))
230+
.willSetStateTo("Second Call"));
231+
232+
stubFor(
233+
post(urlPathEqualTo("/completion"))
234+
.inScenario("Chat Memory")
235+
.whenScenarioStateIs("Second Call")
236+
.willReturn(
237+
aResponse()
238+
.withBodyFile("templatingResponse.json") // The response is not important
239+
.withHeader("Content-Type", "application/json")));
240+
241+
val memory = new InMemoryChatMemory();
242+
val advisor = new MessageChatMemoryAdvisor(memory);
243+
val cl = ChatClient.builder(client).defaultAdvisors(advisor).build();
244+
val prompt1 = new Prompt("What is the capital of France?", defaultOptions);
245+
val prompt2 = new Prompt("And what is the typical food there?", defaultOptions);
246+
247+
cl.prompt(prompt1).call().content();
248+
cl.prompt(prompt2).call().content();
249+
// The response is not important
250+
// We just want to verify that the second call remembered the first call
251+
try (var requestInputStream = fileLoader.apply("chatMemory.json")) {
252+
final String request = new String(requestInputStream.readAllBytes());
253+
verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request)));
254+
}
255+
}
216256
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"orchestration_config": {
3+
"module_configurations": {
4+
"llm_module_config": {
5+
"model_name" : "gpt-4o",
6+
"model_params": {},
7+
"model_version": "latest"
8+
},
9+
"templating_module_config": {
10+
"template": [
11+
{
12+
"role": "user",
13+
"content": "What is the capital of France?"
14+
},
15+
{
16+
"role": "assistant",
17+
"content" : "Le service d'orchestration fonctionne!"
18+
},
19+
{
20+
"role": "user",
21+
"content": "And what is the typical food there?"
22+
}
23+
],
24+
"defaults": {},
25+
"tools": []
26+
}
27+
},
28+
"stream": false
29+
},
30+
"input_params": {},
31+
"messages_history": []
32+
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,17 @@ Object toolCalling(
7979
final String text = message.getText();
8080
return text.isEmpty() ? message.getToolCalls().toString() : text;
8181
}
82+
83+
@GetMapping("/chatMemory")
84+
Object chatMemory(
85+
@Nullable @RequestParam(value = "format", required = false) final String format) {
86+
val response = service.chatMemory();
87+
88+
if ("json".equals(format)) {
89+
return ((OrchestrationSpringChatResponse) response)
90+
.getOrchestrationResponse()
91+
.getOriginalResponse();
92+
}
93+
return response.getResult().getOutput().getText();
94+
}
8295
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOrchestrationService.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
import com.sap.ai.sdk.orchestration.spring.OrchestrationChatOptions;
1010
import java.util.List;
1111
import java.util.Map;
12+
import java.util.Objects;
1213
import javax.annotation.Nonnull;
1314
import lombok.val;
15+
import org.springframework.ai.chat.client.ChatClient;
16+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
17+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
1418
import org.springframework.ai.chat.model.ChatModel;
1519
import org.springframework.ai.chat.model.ChatResponse;
1620
import org.springframework.ai.chat.prompt.Prompt;
@@ -106,4 +110,22 @@ public ChatResponse toolCalling(final boolean internalToolExecutionEnabled) {
106110
val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options);
107111
return client.call(prompt);
108112
}
113+
114+
/**
115+
* Chat request to OpenAI through the Orchestration service using chat memory.
116+
*
117+
* @return the assistant response object
118+
*/
119+
@Nonnull
120+
public ChatResponse chatMemory() {
121+
val memory = new InMemoryChatMemory();
122+
val advisor = new MessageChatMemoryAdvisor(memory);
123+
val cl = ChatClient.builder(client).defaultAdvisors(advisor).build();
124+
val prompt1 = new Prompt("What is the capital of France?", defaultOptions);
125+
val prompt2 = new Prompt("And what is the typical food there?", defaultOptions);
126+
127+
cl.prompt(prompt1).call().content();
128+
return Objects.requireNonNull(
129+
cl.prompt(prompt2).call().chatResponse(), "Chat response is null");
130+
}
109131
}

sample-code/spring-app/src/main/resources/static/index.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,16 @@ <h5 class="mb-1">Orchestration Integration</h5>
596596
</div>
597597
</div>
598598
</li>
599+
<li class="list-group-item">
600+
<div class="info-tooltip">
601+
<button type="submit" formaction="/spring-ai-orchestration/chatMemory"
602+
class="link-offset-2-hover link-underline link-underline-opacity-0 link-underline-opacity-75-hover endpoint">
603+
<code>/spring-ai-orchestration/chatMemory</code>
604+
</button>
605+
<div class="tooltip-content">
606+
The user firsts asks the capital of France, then the typical for there, chat memory will remember that the user is inquiring about France.
607+
</div>
608+
</div>
599609
</ul>
600610
</div>
601611
</div>

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,15 @@ void testToolCallingWithExecution() {
8282
.isExactlyInstanceOf(OrchestrationClientException.class)
8383
.hasMessageContaining("Request failed with status 400 Bad Request");
8484
}
85+
86+
@Test
87+
void testChatMemory() {
88+
ChatResponse response = service.chatMemory();
89+
assertThat(response).isNotNull();
90+
String text = response.getResult().getOutput().getText();
91+
log.info(text);
92+
assertThat(text)
93+
.containsAnyOf(
94+
"French", "onion", "pastries", "cheese", "baguette", "coq au vin", "foie gras");
95+
}
8596
}

0 commit comments

Comments
 (0)