Skip to content

Commit c326b7d

Browse files
Orchestration streaming (#227)
* Orchestration streaming first version * Added unit tests * Formatting * Added documentation * Added tests * Release notes * Applied Alex's review comments --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]>
1 parent a174ba8 commit c326b7d

File tree

17 files changed

+660
-15
lines changed

17 files changed

+660
-15
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,32 @@ Use the grounding module to provide additional context to the AI model.
217217

218218
In this example, the AI model is provided with additional context in the form of grounding information. Note that it is necessary to provide the grounding input via one or more input variables.
219219

220+
### Stream chat completion
221+
222+
It's possible to stream the chat completion as a sequence of delta elements, e.g. to pass them from the application backend to the frontend in real time.
223+
224+
#### Asynchronous Streaming
225+
226+
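The following example consumes the stream with `forEach`, which blocks the calling thread until the stream is exhausted, and prints each delta directly to the console as it arrives: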
227+
228+
```java
229+
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
230+
231+
// try-with-resources on stream ensures the connection will be closed
232+
try (Stream<String> stream = client.streamChatCompletion(prompt, config)) {
233+
stream.forEach(
234+
deltaString -> {
235+
System.out.print(deltaString);
236+
System.out.flush();
237+
});
238+
}
239+
```
240+
241+
#### Spring Boot example
242+
243+
Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java).
244+
It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time.
245+
220246
### Set model parameters
221247

222248
Change your LLM configuration to add model parameters:

docs/release-notes/release_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
### ✨ New Functionality
1414

15-
-
15+
- Added `streamChatCompletion()` and `streamChatCompletionDeltas()` to the `OrchestrationClient`.
1616

1717
### 📈 Improvements
1818

orchestration/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@
112112
<artifactId>mockito-core</artifactId>
113113
<scope>test</scope>
114114
</dependency>
115+
<dependency>
116+
<groupId>org.junit.jupiter</groupId>
117+
<artifactId>junit-jupiter-params</artifactId>
118+
<scope>test</scope>
119+
</dependency>
115120
</dependencies>
116121

117122
<profiles>
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
import static java.util.Spliterator.NONNULL;
5+
import static java.util.Spliterator.ORDERED;
6+
7+
import io.vavr.control.Try;
8+
import java.io.BufferedReader;
9+
import java.io.IOException;
10+
import java.io.InputStream;
11+
import java.io.InputStreamReader;
12+
import java.util.Iterator;
13+
import java.util.NoSuchElementException;
14+
import java.util.Spliterators;
15+
import java.util.concurrent.Callable;
16+
import java.util.function.Function;
17+
import java.util.stream.Stream;
18+
import java.util.stream.StreamSupport;
19+
import javax.annotation.Nonnull;
20+
import javax.annotation.Nullable;
21+
import lombok.AccessLevel;
22+
import lombok.RequiredArgsConstructor;
23+
import lombok.extern.slf4j.Slf4j;
24+
import org.apache.hc.core5.http.HttpEntity;
25+
26+
/**
27+
* Internal utility class to convert from a reading handler to {@link Iterable} and {@link Stream}.
28+
*
29+
* <p><strong>Note:</strong> All operations are sequential in nature. Thread safety is not
30+
* guaranteed.
31+
*
32+
* @param <T> Iterated item type.
33+
*/
34+
@Slf4j
35+
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
36+
class IterableStreamConverter<T> implements Iterator<T> {
37+
/** see DEFAULT_CHAR_BUFFER_SIZE in {@link BufferedReader} * */
38+
static final int BUFFER_SIZE = 8192;
39+
40+
/** Read next entry for Stream or {@code null} when no further entry can be read. */
41+
private final Callable<T> readHandler;
42+
43+
/** Close handler to be called when Stream terminated. */
44+
private final Runnable stopHandler;
45+
46+
/** Error handler to be called when Stream is interrupted. */
47+
private final Function<Exception, RuntimeException> errorHandler;
48+
49+
private boolean isDone = false;
50+
private boolean isNextFetched = false;
51+
private T next = null;
52+
53+
@SuppressWarnings("checkstyle:IllegalCatch")
54+
@Override
55+
public boolean hasNext() {
56+
if (isDone) {
57+
return false;
58+
}
59+
if (isNextFetched) {
60+
return true;
61+
}
62+
try {
63+
next = readHandler.call();
64+
isNextFetched = true;
65+
if (next == null) {
66+
isDone = true;
67+
stopHandler.run();
68+
}
69+
} catch (final Exception e) {
70+
isDone = true;
71+
stopHandler.run();
72+
log.debug("Error while reading next element.", e);
73+
throw errorHandler.apply(e);
74+
}
75+
return !isDone;
76+
}
77+
78+
@Override
79+
public T next() {
80+
if (next == null && !hasNext()) {
81+
throw new NoSuchElementException(); // normally not reached with Stream API
82+
}
83+
isNextFetched = false;
84+
return next;
85+
}
86+
87+
/**
88+
* Create a sequential Stream of lines from an HTTP response string (UTF-8). The underlying {@link
89+
* InputStream} is closed, when the resulting Stream is closed (e.g. via try-with-resources) or
90+
* when an exception occurred.
91+
*
92+
* @param entity The HTTP entity object.
93+
* @return A sequential Stream object.
94+
* @throws OrchestrationClientException if the provided HTTP entity object is {@code null} or
95+
* empty.
96+
*/
97+
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
98+
@Nonnull
99+
static Stream<String> lines(@Nullable final HttpEntity entity)
100+
throws OrchestrationClientException {
101+
if (entity == null) {
102+
throw new OrchestrationClientException("Orchestration service response was empty.");
103+
}
104+
105+
final InputStream inputStream;
106+
try {
107+
inputStream = entity.getContent();
108+
} catch (final IOException e) {
109+
throw new OrchestrationClientException("Failed to read response content.", e);
110+
}
111+
112+
final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE);
113+
final Runnable closeHandler =
114+
() -> Try.run(reader::close).onFailure(e -> log.error("Could not close input stream", e));
115+
final Function<Exception, RuntimeException> errHandler =
116+
e -> new OrchestrationClientException("Parsing response content was interrupted.", e);
117+
118+
final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler);
119+
final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL);
120+
return StreamSupport.stream(spliterator, /* NOT PARALLEL */ false).onClose(closeHandler);
121+
}
122+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
4+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
5+
import java.util.Map;
6+
import javax.annotation.Nonnull;
7+
import javax.annotation.Nullable;
8+
import lombok.val;
9+
10+
/** Orchestration chat completion output delta for streaming. */
11+
public class OrchestrationChatCompletionDelta extends CompletionPostResponse
12+
implements StreamedDelta {
13+
14+
@Nonnull
15+
@Override
16+
// will be fixed once the generated code add a discriminator which will allow this class to extend
17+
// CompletionPostResponseStreaming
18+
@SuppressWarnings("unchecked")
19+
public String getDeltaContent() {
20+
val choices = ((LLMModuleResultSynchronous) getOrchestrationResult()).getChoices();
21+
// Avoid the first delta: "choices":[]
22+
if (!choices.isEmpty()
23+
// Multiple choices are spread out on multiple deltas
24+
// A delta only contains one choice with a variable index
25+
&& choices.get(0).getIndex() == 0) {
26+
27+
final var message = (Map<String, Object>) choices.get(0).getCustomField("delta");
28+
// Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}]
29+
if (message != null && message.get("content") != null) {
30+
return message.get("content").toString();
31+
}
32+
}
33+
return "";
34+
}
35+
36+
@Nullable
37+
@Override
38+
public String getFinishReason() {
39+
return ((LLMModuleResultSynchronous) getOrchestrationResult())
40+
.getChoices()
41+
.get(0)
42+
.getFinishReason();
43+
}
44+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
2323
import java.io.IOException;
2424
import java.util.function.Supplier;
25+
import java.util.stream.Stream;
2526
import javax.annotation.Nonnull;
2627
import lombok.extern.slf4j.Slf4j;
2728
import lombok.val;
2829
import org.apache.hc.client5.http.classic.methods.HttpPost;
2930
import org.apache.hc.core5.http.ContentType;
3031
import org.apache.hc.core5.http.io.entity.StringEntity;
32+
import org.apache.hc.core5.http.message.BasicClassicHttpRequest;
3133

3234
/** Client to execute requests to the orchestration service. */
3335
@Slf4j
@@ -105,6 +107,33 @@ public OrchestrationChatResponse chatCompletion(
105107
return new OrchestrationChatResponse(executeRequest(request));
106108
}
107109

110+
/**
111+
* Generate a completion for the given prompt.
112+
*
113+
* @param prompt a text message.
114+
* @return A stream of message deltas
115+
* @throws OrchestrationClientException if the request fails or if the finish reason is
116+
* content_filter
117+
* @since 1.1.0
118+
*/
119+
@Nonnull
120+
public Stream<String> streamChatCompletion(
121+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
122+
throws OrchestrationClientException {
123+
124+
val request = toCompletionPostRequest(prompt, config);
125+
return streamChatCompletionDeltas(request)
126+
.peek(OrchestrationClient::throwOnContentFilter)
127+
.map(OrchestrationChatCompletionDelta::getDeltaContent);
128+
}
129+
130+
private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) {
131+
final String finishReason = delta.getFinishReason();
132+
if (finishReason != null && finishReason.equals("content_filter")) {
133+
throw new OrchestrationClientException("Content filter filtered the output.");
134+
}
135+
}
136+
108137
/**
109138
* Serializes the given request, executes it and deserializes the response.
110139
*
@@ -205,4 +234,53 @@ CompletionPostResponse executeRequest(@Nonnull final String request) {
205234
throw new OrchestrationClientException("Failed to execute request", e);
206235
}
207236
}
237+
238+
/**
239+
* Generate a completion for the given prompt.
240+
*
241+
* @param request the prompt, including messages and other parameters.
242+
* @return A stream of chat completion delta elements.
243+
* @throws OrchestrationClientException if the request fails
244+
* @since 1.1.0
245+
*/
246+
@Nonnull
247+
public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
248+
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
249+
request.getOrchestrationConfig().setStream(true);
250+
return executeStream("/completion", request, OrchestrationChatCompletionDelta.class);
251+
}
252+
253+
@Nonnull
254+
private <D extends StreamedDelta> Stream<D> executeStream(
255+
@Nonnull final String path,
256+
@Nonnull final Object payload,
257+
@Nonnull final Class<D> deltaType) {
258+
final var request = new HttpPost(path);
259+
serializeAndSetHttpEntity(request, payload);
260+
return streamRequest(request, deltaType);
261+
}
262+
263+
private static void serializeAndSetHttpEntity(
264+
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
265+
try {
266+
final var json = JACKSON.writeValueAsString(payload);
267+
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
268+
} catch (final JsonProcessingException e) {
269+
throw new OrchestrationClientException("Failed to serialize request parameters", e);
270+
}
271+
}
272+
273+
@Nonnull
274+
private <D extends StreamedDelta> Stream<D> streamRequest(
275+
final BasicClassicHttpRequest request, @Nonnull final Class<D> deltaType) {
276+
try {
277+
val destination = destinationSupplier.get();
278+
log.debug("Using destination {} to connect to orchestration service", destination);
279+
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
280+
return new OrchestrationStreamingHandler<>(deltaType)
281+
.handleResponse(client.executeOpen(null, request, null));
282+
} catch (final IOException e) {
283+
throw new OrchestrationClientException("Request to the Orchestration service failed", e);
284+
}
285+
}
208286
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.buildExceptionAndThrow;
5+
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.parseErrorAndThrow;
6+
7+
import java.io.IOException;
8+
import java.util.stream.Stream;
9+
import javax.annotation.Nonnull;
10+
import lombok.RequiredArgsConstructor;
11+
import lombok.extern.slf4j.Slf4j;
12+
import org.apache.hc.core5.http.ClassicHttpResponse;
13+
14+
@Slf4j
15+
@RequiredArgsConstructor
16+
class OrchestrationStreamingHandler<D extends StreamedDelta> {
17+
18+
@Nonnull private final Class<D> deltaType;
19+
20+
/**
21+
* @param response The response to process
22+
* @return A {@link Stream} of a model class instantiated from the response
23+
*/
24+
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
25+
@Nonnull
26+
Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response)
27+
throws OrchestrationClientException {
28+
if (response.getCode() >= 300) {
29+
buildExceptionAndThrow(response);
30+
}
31+
return IterableStreamConverter.lines(response.getEntity())
32+
// half of the lines are empty newlines, the last line is "data: [DONE]"
33+
.peek(line -> log.info("Handler: {}", line))
34+
.filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim()))
35+
.peek(
36+
line -> {
37+
if (!line.startsWith("data: ")) {
38+
final String msg = "Failed to parse response from the Orchestration service";
39+
parseErrorAndThrow(line, new OrchestrationClientException(msg));
40+
}
41+
})
42+
.map(
43+
line -> {
44+
final String data = line.substring(5); // remove "data: "
45+
try {
46+
return JACKSON.readValue(data, deltaType);
47+
} catch (final IOException e) { // exception message e gets lost
48+
log.error(
49+
"Failed to parse the following response from the Orchestration service: {}",
50+
line);
51+
throw new OrchestrationClientException("Failed to parse delta message: " + line, e);
52+
}
53+
});
54+
}
55+
}

0 commit comments

Comments
 (0)