Skip to content

Commit 40e56b3

Browse files
Orchestration streaming first version
1 parent 42ad4bf commit 40e56b3

File tree

9 files changed

+396
-1
lines changed

9 files changed

+396
-1
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
import static java.util.Spliterator.NONNULL;
5+
import static java.util.Spliterator.ORDERED;
6+
7+
import io.vavr.control.Try;
8+
import java.io.BufferedReader;
9+
import java.io.IOException;
10+
import java.io.InputStream;
11+
import java.io.InputStreamReader;
12+
import java.util.Iterator;
13+
import java.util.NoSuchElementException;
14+
import java.util.Spliterators;
15+
import java.util.concurrent.Callable;
16+
import java.util.function.Function;
17+
import java.util.stream.Stream;
18+
import java.util.stream.StreamSupport;
19+
import javax.annotation.Nonnull;
20+
import javax.annotation.Nullable;
21+
import lombok.AccessLevel;
22+
import lombok.RequiredArgsConstructor;
23+
import lombok.extern.slf4j.Slf4j;
24+
import org.apache.hc.core5.http.HttpEntity;
25+
26+
/**
27+
* Internal utility class to convert from a reading handler to {@link Iterable} and {@link Stream}.
28+
*
29+
* <p><strong>Note:</strong> All operations are sequential in nature. Thread safety is not
30+
* guaranteed.
31+
*
32+
* @param <T> Iterated item type.
33+
*/
34+
@Slf4j
35+
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
36+
class IterableStreamConverter<T> implements Iterator<T> {
37+
/** see DEFAULT_CHAR_BUFFER_SIZE in {@link BufferedReader} * */
38+
static final int BUFFER_SIZE = 8192;
39+
40+
/** Read next entry for Stream or {@code null} when no further entry can be read. */
41+
private final Callable<T> readHandler;
42+
43+
/** Close handler to be called when Stream terminated. */
44+
private final Runnable stopHandler;
45+
46+
/** Error handler to be called when Stream is interrupted. */
47+
private final Function<Exception, RuntimeException> errorHandler;
48+
49+
private boolean isDone = false;
50+
private boolean isNextFetched = false;
51+
private T next = null;
52+
53+
@SuppressWarnings("checkstyle:IllegalCatch")
54+
@Override
55+
public boolean hasNext() {
56+
if (isDone) {
57+
return false;
58+
}
59+
if (isNextFetched) {
60+
return true;
61+
}
62+
try {
63+
next = readHandler.call();
64+
isNextFetched = true;
65+
if (next == null) {
66+
isDone = true;
67+
stopHandler.run();
68+
}
69+
} catch (final Exception e) {
70+
isDone = true;
71+
stopHandler.run();
72+
log.debug("Error while reading next element.", e);
73+
throw errorHandler.apply(e);
74+
}
75+
return !isDone;
76+
}
77+
78+
@Override
79+
public T next() {
80+
if (next == null && !hasNext()) {
81+
throw new NoSuchElementException(); // normally not reached with Stream API
82+
}
83+
isNextFetched = false;
84+
return next;
85+
}
86+
87+
/**
88+
* Create a sequential Stream of lines from an HTTP response string (UTF-8). The underlying {@link
89+
* InputStream} is closed, when the resulting Stream is closed (e.g. via try-with-resources) or
90+
* when an exception occurred.
91+
*
92+
* @param entity The HTTP entity object.
93+
* @return A sequential Stream object.
94+
* @throws OrchestrationClientException if the provided HTTP entity object is {@code null} or
95+
* empty.
96+
*/
97+
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
98+
@Nonnull
99+
static Stream<String> lines(@Nullable final HttpEntity entity)
100+
throws OrchestrationClientException {
101+
if (entity == null) {
102+
throw new OrchestrationClientException("OpenAI response was empty.");
103+
}
104+
105+
final InputStream inputStream;
106+
try {
107+
inputStream = entity.getContent();
108+
} catch (final IOException e) {
109+
throw new OrchestrationClientException("Failed to read response content.", e);
110+
}
111+
112+
final var reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8), BUFFER_SIZE);
113+
final Runnable closeHandler =
114+
() -> Try.run(reader::close).onFailure(e -> log.error("Could not close input stream", e));
115+
final Function<Exception, RuntimeException> errHandler =
116+
e -> new OrchestrationClientException("Parsing response content was interrupted.", e);
117+
118+
final var iterator = new IterableStreamConverter<>(reader::readLine, closeHandler, errHandler);
119+
final var spliterator = Spliterators.spliteratorUnknownSize(iterator, ORDERED | NONNULL);
120+
return StreamSupport.stream(spliterator, /* NOT PARALLEL */ false).onClose(closeHandler);
121+
}
122+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
4+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
5+
import java.util.Map;
6+
import javax.annotation.Nonnull;
7+
import javax.annotation.Nullable;
8+
import lombok.val;
9+
10+
/** Orchestration chat completion output delta for streaming. */
11+
public class OrchestrationChatCompletionDelta extends CompletionPostResponse
12+
implements StreamedDelta {
13+
14+
@Nonnull
15+
@Override
16+
// will be fixed once the generated code add a discriminator which will allow this class to extend
17+
// CompletionPostResponseStreaming
18+
@SuppressWarnings("unchecked")
19+
public String getDeltaContent() {
20+
val choices = ((LLMModuleResultSynchronous) getOrchestrationResult()).getChoices();
21+
// Avoid the first delta: "choices":[]
22+
if (!choices.isEmpty()
23+
// Multiple choices are spread out on multiple deltas
24+
// A delta only contains one choice with a variable index
25+
&& choices.get(0).getIndex() == 0) {
26+
27+
final var message = (Map<String, Object>) choices.get(0).getCustomField("delta");
28+
// Avoid the second delta: "choices":[{"delta":{"content":"","role":"assistant"}}]
29+
if (message != null && message.get("content") != null) {
30+
return message.get("content").toString();
31+
}
32+
}
33+
return "";
34+
}
35+
36+
@Nullable
37+
@Override
38+
public String getFinishReason() {
39+
return ((LLMModuleResultSynchronous) getOrchestrationResult())
40+
.getChoices()
41+
.get(0)
42+
.getFinishReason();
43+
}
44+
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException;
2626
import java.io.IOException;
2727
import java.util.function.Supplier;
28+
import java.util.stream.Stream;
2829
import javax.annotation.Nonnull;
2930
import lombok.extern.slf4j.Slf4j;
3031
import lombok.val;
3132
import org.apache.hc.client5.http.classic.methods.HttpPost;
3233
import org.apache.hc.core5.http.ContentType;
3334
import org.apache.hc.core5.http.io.entity.StringEntity;
35+
import org.apache.hc.core5.http.message.BasicClassicHttpRequest;
3436

3537
/** Client to execute requests to the orchestration service. */
3638
@Slf4j
@@ -111,6 +113,23 @@ public OrchestrationChatResponse chatCompletion(
111113
return new OrchestrationChatResponse(executeRequest(request));
112114
}
113115

116+
/**
117+
* Generate a completion for the given prompt.
118+
*
119+
* @param prompt a text message.
120+
* @return A stream of message deltas
121+
* @throws OrchestrationClientException if the request fails or if the finish reason is
122+
* content_filter
123+
*/
124+
@Nonnull
125+
public Stream<OrchestrationChatCompletionDelta> streamChatCompletion(
126+
@Nonnull final OrchestrationPrompt prompt, @Nonnull final OrchestrationModuleConfig config)
127+
throws OrchestrationClientException {
128+
129+
val request = toCompletionPostRequest(prompt, config);
130+
return streamChatCompletionDeltas(request);
131+
}
132+
114133
/**
115134
* Serializes the given request, executes it and deserializes the response.
116135
*
@@ -211,4 +230,53 @@ CompletionPostResponse executeRequest(@Nonnull final String request) {
211230
throw new OrchestrationClientException("Failed to execute request", e);
212231
}
213232
}
233+
234+
/**
235+
* Generate a completion for the given prompt.
236+
*
237+
* @param request the prompt, including messages and other parameters.
238+
* @return A stream of chat completion delta elements.
239+
* @throws OrchestrationClientException if the request fails
240+
*/
241+
@Nonnull
242+
public Stream<OrchestrationChatCompletionDelta> streamChatCompletionDeltas(
243+
@Nonnull final CompletionPostRequest request) throws OrchestrationClientException {
244+
request.getOrchestrationConfig().setStream(true);
245+
return executeStream("/completion", request, OrchestrationChatCompletionDelta.class);
246+
}
247+
248+
@Nonnull
249+
private <D extends StreamedDelta> Stream<D> executeStream(
250+
@Nonnull final String path,
251+
@Nonnull final Object payload,
252+
@Nonnull final Class<D> deltaType) {
253+
final var request = new HttpPost(path);
254+
serializeAndSetHttpEntity(request, payload);
255+
return streamRequest(request, deltaType);
256+
}
257+
258+
private static void serializeAndSetHttpEntity(
259+
@Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) {
260+
try {
261+
final var json = JACKSON.writeValueAsString(payload);
262+
request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
263+
} catch (final JsonProcessingException e) {
264+
throw new OrchestrationClientException("Failed to serialize request parameters", e);
265+
}
266+
}
267+
268+
@Nonnull
269+
private <D extends StreamedDelta> Stream<D> streamRequest(
270+
final BasicClassicHttpRequest request, @Nonnull final Class<D> deltaType) {
271+
val postRequest = new HttpPost("/completion");
272+
try {
273+
val destination = destinationSupplier.get();
274+
log.debug("Using destination {} to connect to orchestration service", destination);
275+
val client = ApacheHttpClient5Accessor.getHttpClient(destination);
276+
return new OrchestrationStreamingHandler<>(deltaType)
277+
.handleResponse(client.executeOpen(null, request, null));
278+
} catch (final IOException e) {
279+
throw new OrchestrationClientException("Request to the Orchestration service failed", e);
280+
}
281+
}
214282
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.buildExceptionAndThrow;
5+
import static com.sap.ai.sdk.orchestration.OrchestrationResponseHandler.parseErrorAndThrow;
6+
7+
import java.io.IOException;
8+
import java.util.stream.Stream;
9+
import javax.annotation.Nonnull;
10+
import lombok.RequiredArgsConstructor;
11+
import lombok.extern.slf4j.Slf4j;
12+
import org.apache.hc.core5.http.ClassicHttpResponse;
13+
14+
@Slf4j
15+
@RequiredArgsConstructor
16+
class OrchestrationStreamingHandler<D extends StreamedDelta> {
17+
18+
@Nonnull private final Class<D> deltaType;
19+
20+
/**
21+
* @param response The response to process
22+
* @return A {@link Stream} of a model class instantiated from the response
23+
*/
24+
@SuppressWarnings("PMD.CloseResource") // Stream is closed automatically when consumed
25+
@Nonnull
26+
Stream<D> handleResponse(@Nonnull final ClassicHttpResponse response)
27+
throws OrchestrationClientException {
28+
if (response.getCode() >= 300) {
29+
buildExceptionAndThrow(response);
30+
}
31+
return IterableStreamConverter.lines(response.getEntity())
32+
// half of the lines are empty newlines, the last line is "data: [DONE]"
33+
.peek(line -> log.info("Handler: {}", line))
34+
.filter(line -> !line.isEmpty() && !"data: [DONE]".equals(line.trim()))
35+
.peek(
36+
line -> {
37+
if (!line.startsWith("data: ")) {
38+
final String msg = "Failed to parse response from OpenAI model";
39+
parseErrorAndThrow(line, new OrchestrationClientException(msg));
40+
}
41+
})
42+
.map(
43+
line -> {
44+
final String data = line.substring(5); // remove "data: "
45+
try {
46+
return JACKSON.readValue(data, deltaType);
47+
} catch (final IOException e) { // exception message e gets lost
48+
log.error("Failed to parse the following response from OpenAI model: {}", line);
49+
throw new OrchestrationClientException("Failed to parse delta message: " + line, e);
50+
}
51+
});
52+
}
53+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import javax.annotation.Nonnull;
4+
import javax.annotation.Nullable;
5+
6+
/**
7+
* Interface for streamed delta classes.
8+
*
9+
* <p>This interface defines a method to retrieve the content from a delta, which is a chunk in a
10+
* stream of data. Implementations of this interface should provide the logic to extract the
11+
* relevant content from the delta.
12+
*/
13+
public interface StreamedDelta {
14+
15+
/**
16+
* Get the message content from the delta.
17+
*
18+
* <p>Note: If there are multiple choices only the first one is returned
19+
*
20+
* <p>Note: Some deltas do not contain any content
21+
*
22+
* @return the message content or empty string.
23+
*/
24+
@Nonnull
25+
String getDeltaContent();
26+
27+
/**
28+
* Reason for finish. The possible values are:
29+
*
30+
* <p>{@code stop}: API returned complete message, or a message terminated by one of the stop
31+
* sequences provided via the stop parameter
32+
*
33+
* <p>{@code length}: Incomplete model output due to max_tokens parameter or token limit
34+
*
35+
* <p>{@code function_call}: The model decided to call a function
36+
*
37+
* <p>{@code content_filter}: Omitted content due to a flag from our content filters
38+
*
39+
* <p>{@code null}: API response still in progress or incomplete
40+
*/
41+
@Nullable
42+
String getFinishReason();
43+
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
123123
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
124124
}
125125

126-
private static void send(
126+
static void send(
127127
@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
128128
try {
129129
emitter.send(chunk);

0 commit comments

Comments
 (0)