Skip to content

Commit 83d7807

Browse files
Fixed
1 parent dfe1246 commit 83d7807

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414
import java.util.Map;
1515
import java.util.function.Function;
16+
import java.util.stream.Stream;
1617
import javax.annotation.Nonnull;
1718
import lombok.RequiredArgsConstructor;
1819
import lombok.extern.slf4j.Slf4j;
@@ -65,29 +66,35 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
6566

6667
val orchestrationPrompt = toOrchestrationPrompt(prompt);
6768
val request = toCompletionPostRequest(orchestrationPrompt, options.getConfig());
68-
val stream =
69-
client
70-
.streamChatCompletionDeltas(request)
71-
.peek(OrchestrationChatModel::throwOnContentFilter)
72-
.map(OrchestrationSpringChatDelta::new);
73-
return Flux.generate(
74-
stream::iterator,
75-
(iterator, sink) -> {
76-
if (iterator.hasNext()) {
77-
sink.next(iterator.next());
78-
} else {
79-
sink.complete();
80-
}
81-
return iterator;
69+
val stream = client.streamChatCompletionDeltas(request);
70+
71+
final Flux<OrchestrationChatCompletionDelta> flux =
72+
Flux.generate(
73+
stream::iterator,
74+
(iterator, sink) -> {
75+
if (iterator.hasNext()) {
76+
sink.next(iterator.next());
77+
} else {
78+
sink.complete();
79+
}
80+
return iterator;
81+
});
82+
return flux.map(
83+
delta -> {
84+
throwOnContentFilter(stream, delta);
85+
return new OrchestrationSpringChatDelta(delta);
8286
});
8387
}
8488
throw new IllegalArgumentException(
8589
"Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))");
8690
}
8791

88-
private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) {
92+
private static void throwOnContentFilter(
93+
@Nonnull final Stream<OrchestrationChatCompletionDelta> stream,
94+
@Nonnull final OrchestrationChatCompletionDelta delta) {
8995
final String finishReason = delta.getFinishReason();
9096
if (finishReason != null && finishReason.equals("content_filter")) {
97+
stream.close();
9198
throw new OrchestrationClientException("Content filter filtered the output.");
9299
}
93100
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModelTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ void testThrowsOnMissingLlmConfig() {
103103
.hasMessageContaining("LLM config is required");
104104
}
105105

106-
@Disabled
107106
@Test
108107
void streamChatCompletionOutputFilterErrorHandling() throws IOException {
109108
try (var inputStream = spy(fileLoader.apply("streamChatCompletionOutputFilter.txt"))) {

0 commit comments

Comments
 (0)