|
13 | 13 | import java.util.List; |
14 | 14 | import java.util.Map; |
15 | 15 | import java.util.function.Function; |
| 16 | +import java.util.stream.Stream; |
16 | 17 | import javax.annotation.Nonnull; |
17 | 18 | import lombok.RequiredArgsConstructor; |
18 | 19 | import lombok.extern.slf4j.Slf4j; |
@@ -65,29 +66,35 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) { |
65 | 66 |
|
66 | 67 | val orchestrationPrompt = toOrchestrationPrompt(prompt); |
67 | 68 | val request = toCompletionPostRequest(orchestrationPrompt, options.getConfig()); |
68 | | - val stream = |
69 | | - client |
70 | | - .streamChatCompletionDeltas(request) |
71 | | - .peek(OrchestrationChatModel::throwOnContentFilter) |
72 | | - .map(OrchestrationSpringChatDelta::new); |
73 | | - return Flux.generate( |
74 | | - stream::iterator, |
75 | | - (iterator, sink) -> { |
76 | | - if (iterator.hasNext()) { |
77 | | - sink.next(iterator.next()); |
78 | | - } else { |
79 | | - sink.complete(); |
80 | | - } |
81 | | - return iterator; |
| 69 | + val stream = client.streamChatCompletionDeltas(request); |
| 70 | + |
| 71 | + final Flux<OrchestrationChatCompletionDelta> flux = |
| 72 | + Flux.generate( |
| 73 | + stream::iterator, |
| 74 | + (iterator, sink) -> { |
| 75 | + if (iterator.hasNext()) { |
| 76 | + sink.next(iterator.next()); |
| 77 | + } else { |
| 78 | + sink.complete(); |
| 79 | + } |
| 80 | + return iterator; |
| 81 | + }); |
| 82 | + return flux.map( |
| 83 | + delta -> { |
| 84 | + throwOnContentFilter(stream, delta); |
| 85 | + return new OrchestrationSpringChatDelta(delta); |
82 | 86 | }); |
83 | 87 | } |
84 | 88 | throw new IllegalArgumentException( |
85 | 89 | "Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))"); |
86 | 90 | } |
87 | 91 |
|
88 | | - private static void throwOnContentFilter(@Nonnull final OrchestrationChatCompletionDelta delta) { |
| 92 | + private static void throwOnContentFilter( |
| 93 | + @Nonnull final Stream<OrchestrationChatCompletionDelta> stream, |
| 94 | + @Nonnull final OrchestrationChatCompletionDelta delta) { |
89 | 95 | final String finishReason = delta.getFinishReason(); |
90 | 96 | if (finishReason != null && finishReason.equals("content_filter")) { |
| 97 | + stream.close(); |
91 | 98 | throw new OrchestrationClientException("Content filter filtered the output."); |
92 | 99 | } |
93 | 100 | } |
|
0 commit comments