Skip to content

Commit 7ec7dff

Browse files
CharlesDuboisSAPbot-sdk-jsnewtork
authored
feat: Spring AI Orchestration streaming (#292)
* Added stream to Spring AI Orchestration * Fixed * Formatting * Added e2e test + docs * Removed unused field * Update orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatDelta.java Co-authored-by: Alexander Dümont <[email protected]> * Remove comments * Fixed mistake * Fixed mistake --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]> Co-authored-by: Alexander Dümont <[email protected]>
1 parent a7b2b87 commit 7ec7dff

File tree

15 files changed

+407
-33
lines changed

15 files changed

+407
-33
lines changed

docs/guides/SPRING_AI_INTEGRATION.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ Prompt prompt = new Prompt("What is the capital of France?", opts);
5050
ChatResponse response = client.call(prompt);
5151
```
5252

53-
Please
54-
find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java).
53+
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java).
5554

5655
## Orchestration Masking
5756

@@ -76,3 +75,27 @@ ChatResponse response = client.call(prompt);
7675

7776
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java).
78+
79+
## Stream chat completion
80+
81+
It's possible to pass a stream of chat completion delta elements, e.g. from the application backend
82+
to the frontend in real time.
83+
84+
```java
85+
ChatModel client = new OrchestrationChatModel();
86+
OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);
87+
OrchestrationChatOptions opts = new OrchestrationChatOptions(config);
88+
89+
Prompt prompt =
90+
new Prompt(
91+
"Can you give me the first 100 numbers of the Fibonacci sequence?", opts);
92+
Flux<ChatResponse> flux = client.stream(prompt);
93+
94+
// also possible to keep only the chat completion text
95+
Flux<String> responseFlux =
96+
flux.map(chatResponse -> chatResponse.getResult().getOutput().getContent());
97+
```
98+
99+
_Note: A Spring endpoint can return `Flux` instead of `ResponseEntity`._
100+
101+
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java).

orchestration/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@
5959
<artifactId>spring-ai-core</artifactId>
6060
<optional>true</optional>
6161
</dependency>
62+
<dependency>
63+
<groupId>io.projectreactor</groupId>
64+
<artifactId>reactor-core</artifactId>
65+
<optional>true</optional>
66+
</dependency>
6267
<dependency>
6368
<groupId>org.apache.httpcomponents.core5</groupId>
6469
<artifactId>httpcore5</artifactId>

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
package com.sap.ai.sdk.orchestration.spring;
22

3+
import static com.sap.ai.sdk.orchestration.OrchestrationClient.toCompletionPostRequest;
4+
35
import com.google.common.annotations.Beta;
46
import com.sap.ai.sdk.orchestration.AssistantMessage;
7+
import com.sap.ai.sdk.orchestration.OrchestrationChatCompletionDelta;
58
import com.sap.ai.sdk.orchestration.OrchestrationClient;
9+
import com.sap.ai.sdk.orchestration.OrchestrationClientException;
610
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
711
import com.sap.ai.sdk.orchestration.SystemMessage;
812
import com.sap.ai.sdk.orchestration.UserMessage;
913
import java.util.List;
1014
import java.util.Map;
1115
import java.util.function.Function;
16+
import java.util.stream.Stream;
1217
import javax.annotation.Nonnull;
1318
import lombok.RequiredArgsConstructor;
1419
import lombok.extern.slf4j.Slf4j;
@@ -17,6 +22,7 @@
1722
import org.springframework.ai.chat.model.ChatModel;
1823
import org.springframework.ai.chat.model.ChatResponse;
1924
import org.springframework.ai.chat.prompt.Prompt;
25+
import reactor.core.publisher.Flux;
2026

2127
/**
2228
* Spring AI integration for the orchestration service.
@@ -52,6 +58,47 @@ public ChatResponse call(@Nonnull final Prompt prompt) {
5258
"Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))");
5359
}
5460

61+
@Override
62+
@Nonnull
63+
public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
64+
65+
if (prompt.getOptions() instanceof OrchestrationChatOptions options) {
66+
67+
val orchestrationPrompt = toOrchestrationPrompt(prompt);
68+
val request = toCompletionPostRequest(orchestrationPrompt, options.getConfig());
69+
val stream = client.streamChatCompletionDeltas(request);
70+
71+
final Flux<OrchestrationChatCompletionDelta> flux =
72+
Flux.generate(
73+
stream::iterator,
74+
(iterator, sink) -> {
75+
if (iterator.hasNext()) {
76+
sink.next(iterator.next());
77+
} else {
78+
sink.complete();
79+
}
80+
return iterator;
81+
});
82+
return flux.map(
83+
delta -> {
84+
throwOnContentFilter(stream, delta);
85+
return new OrchestrationSpringChatDelta(delta);
86+
});
87+
}
88+
throw new IllegalArgumentException(
89+
"Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))");
90+
}
91+
92+
private static void throwOnContentFilter(
93+
@Nonnull final Stream<OrchestrationChatCompletionDelta> stream,
94+
@Nonnull final OrchestrationChatCompletionDelta delta) {
95+
final String finishReason = delta.getFinishReason();
96+
if (finishReason != null && finishReason.equals("content_filter")) {
97+
stream.close();
98+
throw new OrchestrationClientException("Content filter filtered the output.");
99+
}
100+
}
101+
55102
@Nonnull
56103
private OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt) {
57104
val messages = toOrchestrationMessages(prompt.getInstructions());
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import com.google.common.annotations.Beta;
4+
import com.sap.ai.sdk.orchestration.OrchestrationChatCompletionDelta;
5+
import com.sap.ai.sdk.orchestration.model.LLMChoice;
6+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
7+
import com.sap.ai.sdk.orchestration.model.TokenUsage;
8+
import java.util.List;
9+
import java.util.Map;
10+
import javax.annotation.Nonnull;
11+
import lombok.EqualsAndHashCode;
12+
import lombok.Value;
13+
import lombok.val;
14+
import org.springframework.ai.chat.messages.AssistantMessage;
15+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
16+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
17+
import org.springframework.ai.chat.metadata.DefaultUsage;
18+
import org.springframework.ai.chat.model.ChatResponse;
19+
import org.springframework.ai.chat.model.Generation;
20+
21+
/**
22+
* Response from the orchestration service in a Spring AI {@link ChatResponse}.
23+
*
24+
* @since 1.2.0
25+
*/
26+
@Beta
27+
@Value
28+
@EqualsAndHashCode(callSuper = true)
29+
public class OrchestrationSpringChatDelta extends ChatResponse {
30+
31+
OrchestrationSpringChatDelta(@Nonnull final OrchestrationChatCompletionDelta delta) {
32+
super(
33+
toGenerations((LLMModuleResultSynchronous) delta.getOrchestrationResult()),
34+
toChatResponseMetadata((LLMModuleResultSynchronous) delta.getOrchestrationResult()));
35+
}
36+
37+
@Nonnull
38+
static List<Generation> toGenerations(@Nonnull final LLMModuleResultSynchronous result) {
39+
return result.getChoices().stream().map(OrchestrationSpringChatDelta::toGeneration).toList();
40+
}
41+
42+
@Nonnull
43+
static Generation toGeneration(@Nonnull final LLMChoice choice) {
44+
val metadata = ChatGenerationMetadata.builder().finishReason(choice.getFinishReason());
45+
metadata.metadata("index", choice.getIndex());
46+
if (!choice.getLogprobs().isEmpty()) {
47+
metadata.metadata("logprobs", choice.getLogprobs());
48+
}
49+
return new Generation(new AssistantMessage(getContent(choice)), metadata.build());
50+
}
51+
52+
@Nonnull
53+
private static String getContent(@Nonnull final LLMChoice choice) {
54+
return choice.getCustomField("delta") instanceof Map<?, ?> delta
55+
&& delta.get("content") instanceof String content
56+
? content
57+
: "";
58+
}
59+
60+
@Nonnull
61+
static ChatResponseMetadata toChatResponseMetadata(
62+
@Nonnull final LLMModuleResultSynchronous orchestrationResult) {
63+
val metadataBuilder = ChatResponseMetadata.builder();
64+
65+
metadataBuilder
66+
.id(orchestrationResult.getId())
67+
.model(orchestrationResult.getModel())
68+
.keyValue("object", orchestrationResult.getObject())
69+
.keyValue("created", orchestrationResult.getCreated());
70+
if (orchestrationResult.getUsage() != null) {
71+
metadataBuilder.usage(toDefaultUsage(orchestrationResult.getUsage()));
72+
}
73+
return metadataBuilder.build();
74+
}
75+
76+
@Nonnull
77+
private static DefaultUsage toDefaultUsage(@Nonnull final TokenUsage usage) {
78+
return new DefaultUsage(
79+
usage.getPromptTokens().longValue(),
80+
usage.getCompletionTokens().longValue(),
81+
usage.getTotalTokens().longValue());
82+
}
83+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringChatResponse.java

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import com.sap.ai.sdk.orchestration.model.LLMChoice;
66
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
77
import com.sap.ai.sdk.orchestration.model.TokenUsage;
8-
import java.util.HashMap;
98
import java.util.List;
10-
import java.util.Map;
119
import javax.annotation.Nonnull;
1210
import lombok.EqualsAndHashCode;
1311
import lombok.Value;
1412
import lombok.val;
1513
import org.springframework.ai.chat.messages.AssistantMessage;
14+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
1615
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
1716
import org.springframework.ai.chat.metadata.DefaultUsage;
1817
import org.springframework.ai.chat.model.ChatResponse;
@@ -43,21 +42,18 @@ public class OrchestrationSpringChatResponse extends ChatResponse {
4342

4443
@Nonnull
4544
static List<Generation> toGenerations(@Nonnull final LLMModuleResultSynchronous result) {
46-
return result.getChoices().stream()
47-
.map(OrchestrationSpringChatResponse::toAssistantMessage)
48-
.map(Generation::new)
49-
.toList();
45+
return result.getChoices().stream().map(OrchestrationSpringChatResponse::toGeneration).toList();
5046
}
5147

5248
@Nonnull
53-
static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choice) {
54-
final Map<String, Object> metadata = new HashMap<>();
55-
metadata.put("finish_reason", choice.getFinishReason());
56-
metadata.put("index", choice.getIndex());
49+
static Generation toGeneration(@Nonnull final LLMChoice choice) {
50+
val metadata = ChatGenerationMetadata.builder().finishReason(choice.getFinishReason());
51+
metadata.metadata("index", choice.getIndex());
5752
if (!choice.getLogprobs().isEmpty()) {
58-
metadata.put("logprobs", choice.getLogprobs());
53+
metadata.metadata("logprobs", choice.getLogprobs());
5954
}
60-
return new AssistantMessage(choice.getMessage().getContent(), metadata);
55+
val message = new AssistantMessage(choice.getMessage().getContent());
56+
return new Generation(message, metadata.build());
6157
}
6258

6359
@Nonnull
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import com.sap.ai.sdk.orchestration.model.LLMChoice;
6+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
7+
import com.sap.ai.sdk.orchestration.model.ResponseChatMessage;
8+
import com.sap.ai.sdk.orchestration.model.TokenUsage;
9+
import java.util.List;
10+
import java.util.Map;
11+
import org.junit.jupiter.api.Test;
12+
import org.springframework.ai.chat.metadata.EmptyUsage;
13+
import org.springframework.ai.chat.model.Generation;
14+
15+
class OrchestrationChatDeltaTest {
16+
17+
@Test
18+
void testToGeneration() {
19+
var choice =
20+
LLMChoice.create()
21+
.index(0)
22+
.message(ResponseChatMessage.create().role("wrong").content("wrong"))
23+
.finishReason("stop");
24+
// this will be fixed once the spec is fixed
25+
choice.setCustomField("delta", Map.of("content", "Hello, world!"));
26+
27+
Generation generation = OrchestrationSpringChatDelta.toGeneration(choice);
28+
29+
assertThat(generation.getOutput().getContent()).isEqualTo("Hello, world!");
30+
assertThat(generation.getMetadata().getFinishReason()).isEqualTo("stop");
31+
assertThat(generation.getMetadata().<Integer>get("index")).isEqualTo(0);
32+
}
33+
34+
@Test
35+
void testToChatResponseMetadata() {
36+
var moduleResult =
37+
LLMModuleResultSynchronous.create()
38+
.id("test-id")
39+
._object("test-object")
40+
.created(123456789)
41+
.model("test-model")
42+
.choices(List.of())
43+
.usage(TokenUsage.create().completionTokens(20).promptTokens(10).totalTokens(30));
44+
45+
var metadata = OrchestrationSpringChatDelta.toChatResponseMetadata(moduleResult);
46+
47+
assertThat(metadata.getId()).isEqualTo("test-id");
48+
assertThat(metadata.getModel()).isEqualTo("test-model");
49+
assertThat(metadata.<String>get("object")).isEqualTo("test-object");
50+
assertThat(metadata.<Integer>get("created")).isEqualTo(123456789);
51+
52+
var usage = metadata.getUsage();
53+
54+
assertThat(usage.getPromptTokens()).isEqualTo(10L);
55+
assertThat(usage.getGenerationTokens()).isEqualTo(20L);
56+
assertThat(usage.getTotalTokens()).isEqualTo(30L);
57+
58+
// delta without token usage
59+
moduleResult.usage(null);
60+
metadata = OrchestrationSpringChatDelta.toChatResponseMetadata(moduleResult);
61+
assertThat(metadata.getUsage()).isInstanceOf(EmptyUsage.class);
62+
}
63+
}

0 commit comments

Comments
 (0)