|
1 | 1 | package com.theokanning.openai.service; |
2 | 2 |
|
3 | 3 | import com.fasterxml.jackson.annotation.JsonInclude; |
| 4 | +import com.fasterxml.jackson.core.JsonProcessingException; |
4 | 5 | import com.fasterxml.jackson.core.type.TypeReference; |
5 | 6 | import com.fasterxml.jackson.databind.DeserializationFeature; |
6 | 7 | import com.fasterxml.jackson.databind.ObjectMapper; |
|
72 | 73 | import java.util.*; |
73 | 74 | import java.util.concurrent.ExecutorService; |
74 | 75 | import java.util.concurrent.TimeUnit; |
| 76 | +import java.util.function.BiConsumer; |
| 77 | +import java.util.function.Supplier; |
75 | 78 |
|
76 | 79 | public class OpenAiService { |
77 | 80 |
|
@@ -190,7 +193,17 @@ public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) |
190 | 193 |
|
191 | 194 | public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) { |
192 | 195 | request.setStream(true); |
193 | | - return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class); |
| 196 | + return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class, new BiConsumer<ChatCompletionChunk, SSE>() { |
| 197 | + @Override |
| 198 | + public void accept(ChatCompletionChunk chatCompletionChunk, SSE sse) { |
| 199 | + chatCompletionChunk.setSource(sse.getData()); |
| 200 | + } |
| 201 | + }, new Supplier<ChatCompletionChunk>() { |
| 202 | + @Override |
| 203 | + public ChatCompletionChunk get() { |
| 204 | + return new ChatCompletionChunk(); |
| 205 | + } |
| 206 | + }); |
194 | 207 | } |
195 | 208 |
|
196 | 209 |
|
@@ -692,6 +705,31 @@ public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) { |
692 | 705 | return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl)); |
693 | 706 | } |
694 | 707 |
|
| 708 | + /** |
| 709 | + * Calls the Open AI api and returns a Flowable of type T for streaming |
| 710 | + * omitting the last message. |
| 711 | + * @param apiCall The api call |
| 712 | + * @param cl Class of type T to return |
| 713 | + * @param consumer After the instance creation is complete |
| 714 | + * @param newInstance If the serialization fails, call this interface to get an instance |
| 715 | + */ |
| 716 | + public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl, BiConsumer<T, SSE> consumer, |
| 717 | + Supplier<T> newInstance) { |
| 718 | + return stream(apiCall, true).map(sse -> { |
| 719 | + try { |
| 720 | + T t = mapper.readValue(sse.getData(), cl); |
| 721 | + if (Objects.nonNull(consumer)) { |
| 722 | + consumer.accept(t, sse); |
| 723 | + } |
| 724 | + return t; |
| 725 | + } catch (JsonProcessingException e) { |
| 726 | + T t = newInstance.get(); |
| 727 | + consumer.accept(t, sse); |
| 728 | + return t; |
| 729 | + } |
| 730 | + }); |
| 731 | + } |
| 732 | + |
695 | 733 | /** |
696 | 734 | * Shuts down the OkHttp ExecutorService. |
697 | 735 | * The default behaviour of OkHttp's ExecutorService (ConnectionPool) |
@@ -758,6 +796,26 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp |
758 | 796 | }); |
759 | 797 | } |
760 | 798 |
|
| 799 | + public Flowable<ChatMessageAccumulatorWrapper> mapStreamToAccumulatorWrapper(Flowable<ChatCompletionChunk> flowable) { |
| 800 | + ChatFunctionCall functionCall = new ChatFunctionCall(null, null); |
| 801 | + AssistantMessage accumulatedMessage = new AssistantMessage(); |
| 802 | + return flowable.map(chunk -> { |
| 803 | + List<ChatCompletionChoice> choices = chunk.getChoices(); |
| 804 | + AssistantMessage messageChunk = null; |
| 805 | + if (null != choices && !choices.isEmpty()) { |
| 806 | + ChatCompletionChoice firstChoice = choices.get(0); |
| 807 | + messageChunk = firstChoice.getMessage(); |
| 808 | + appendContent(messageChunk, accumulatedMessage); |
| 809 | + processFunctionCall(messageChunk, functionCall, accumulatedMessage); |
| 810 | + processToolCalls(messageChunk, accumulatedMessage); |
| 811 | + if (firstChoice.getFinishReason() != null) { |
| 812 | + handleFinishReason(firstChoice.getFinishReason(), functionCall, accumulatedMessage); |
| 813 | + } |
| 814 | + } |
| 815 | + ChatMessageAccumulator chatMessageAccumulator = new ChatMessageAccumulator(messageChunk, accumulatedMessage, chunk.getUsage()); |
| 816 | + return new ChatMessageAccumulatorWrapper(chatMessageAccumulator, chunk); |
| 817 | + }); |
| 818 | + } |
761 | 819 |
|
762 | 820 | /** |
763 | 821 | * 处理消息块中的函数调用。 |
|
0 commit comments