|
| 1 | +package com.sap.ai.sdk.foundationmodels.openai.spring; |
| 2 | + |
| 3 | +import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled; |
| 4 | + |
| 5 | +import com.fasterxml.jackson.core.JsonProcessingException; |
| 6 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 7 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 8 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; |
| 9 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; |
| 10 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; |
| 11 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; |
| 12 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; |
| 13 | +import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; |
| 14 | +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall; |
| 15 | +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; |
| 16 | +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner; |
| 17 | +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; |
| 18 | +import io.vavr.control.Option; |
| 19 | +import java.math.BigDecimal; |
| 20 | +import java.util.ArrayList; |
| 21 | +import java.util.List; |
| 22 | +import java.util.Map; |
| 23 | +import java.util.function.Function; |
| 24 | +import javax.annotation.Nonnull; |
| 25 | +import lombok.RequiredArgsConstructor; |
| 26 | +import lombok.extern.slf4j.Slf4j; |
| 27 | +import lombok.val; |
| 28 | +import org.springframework.ai.chat.messages.AssistantMessage; |
| 29 | +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; |
| 30 | +import org.springframework.ai.chat.messages.Message; |
| 31 | +import org.springframework.ai.chat.messages.ToolResponseMessage; |
| 32 | +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; |
| 33 | +import org.springframework.ai.chat.model.ChatModel; |
| 34 | +import org.springframework.ai.chat.model.ChatResponse; |
| 35 | +import org.springframework.ai.chat.model.Generation; |
| 36 | +import org.springframework.ai.chat.prompt.ChatOptions; |
| 37 | +import org.springframework.ai.chat.prompt.Prompt; |
| 38 | +import org.springframework.ai.model.tool.DefaultToolCallingManager; |
| 39 | +import org.springframework.ai.model.tool.ToolCallingChatOptions; |
| 40 | +import reactor.core.publisher.Flux; |
| 41 | + |
| 42 | +/** |
| 43 | + * OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions. |
| 44 | + */ |
| 45 | +@Slf4j |
| 46 | +@RequiredArgsConstructor |
| 47 | +public class OpenAiChatModel implements ChatModel { |
| 48 | + |
| 49 | + private final OpenAiClient client; |
| 50 | + |
| 51 | + @Nonnull |
| 52 | + private final DefaultToolCallingManager toolCallingManager = |
| 53 | + DefaultToolCallingManager.builder().build(); |
| 54 | + |
| 55 | + @Override |
| 56 | + @Nonnull |
| 57 | + public ChatResponse call(@Nonnull final Prompt prompt) { |
| 58 | + val options = prompt.getOptions(); |
| 59 | + var request = new OpenAiChatCompletionRequest(extractMessages(prompt)); |
| 60 | + |
| 61 | + if (options != null) { |
| 62 | + request = extractOptions(request, options); |
| 63 | + } |
| 64 | + if ((options instanceof ToolCallingChatOptions toolOptions)) { |
| 65 | + request = request.withTools(extractTools(toolOptions)); |
| 66 | + } |
| 67 | + |
| 68 | + val result = client.chatCompletion(request); |
| 69 | + val response = new ChatResponse(toGenerations(result)); |
| 70 | + |
| 71 | + if (options != null && isInternalToolExecutionEnabled(options) && response.hasToolCalls()) { |
| 72 | + val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); |
| 73 | + // Send the tool execution result back to the model. |
| 74 | + return call(new Prompt(toolExecutionResult.conversationHistory(), options)); |
| 75 | + } |
| 76 | + return response; |
| 77 | + } |
| 78 | + |
| 79 | + @Override |
| 80 | + @Nonnull |
| 81 | + public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) { |
| 82 | + val options = prompt.getOptions(); |
| 83 | + var request = new OpenAiChatCompletionRequest(extractMessages(prompt)); |
| 84 | + |
| 85 | + if (options != null) { |
| 86 | + request = extractOptions(request, options); |
| 87 | + } |
| 88 | + if ((options instanceof ToolCallingChatOptions toolOptions)) { |
| 89 | + request = request.withTools(extractTools(toolOptions)); |
| 90 | + } |
| 91 | + |
| 92 | + val stream = client.streamChatCompletionDeltas(request); |
| 93 | + final Flux<OpenAiChatCompletionDelta> flux = |
| 94 | + Flux.generate( |
| 95 | + stream::iterator, |
| 96 | + (iterator, sink) -> { |
| 97 | + if (iterator.hasNext()) { |
| 98 | + sink.next(iterator.next()); |
| 99 | + } else { |
| 100 | + sink.complete(); |
| 101 | + } |
| 102 | + return iterator; |
| 103 | + }); |
| 104 | + return flux.map( |
| 105 | + delta -> { |
| 106 | + val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of()); |
| 107 | + val metadata = |
| 108 | + ChatGenerationMetadata.builder().finishReason(delta.getFinishReason()).build(); |
| 109 | + return new ChatResponse(List.of(new Generation(assistantMessage, metadata))); |
| 110 | + }); |
| 111 | + } |
| 112 | + |
| 113 | + private static List<OpenAiMessage> extractMessages(final Prompt prompt) { |
| 114 | + final List<OpenAiMessage> result = new ArrayList<>(); |
| 115 | + for (final Message message : prompt.getInstructions()) { |
| 116 | + switch (message.getMessageType()) { |
| 117 | + case USER -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.user(t))); |
| 118 | + case SYSTEM -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.system(t))); |
| 119 | + case ASSISTANT -> addAssistantMessage(result, (AssistantMessage) message); |
| 120 | + case TOOL -> addToolMessages(result, (ToolResponseMessage) message); |
| 121 | + } |
| 122 | + } |
| 123 | + return result; |
| 124 | + } |
| 125 | + |
| 126 | + private static void addAssistantMessage( |
| 127 | + final List<OpenAiMessage> result, final AssistantMessage message) { |
| 128 | + if (message.getText() != null) { |
| 129 | + result.add(OpenAiMessage.assistant(message.getText())); |
| 130 | + return; |
| 131 | + } |
| 132 | + final Function<ToolCall, OpenAiToolCall> callTranslate = |
| 133 | + toolCall -> OpenAiToolCall.function(toolCall.id(), toolCall.name(), toolCall.arguments()); |
| 134 | + val calls = message.getToolCalls().stream().map(callTranslate).toList(); |
| 135 | + result.add(OpenAiMessage.assistant(calls)); |
| 136 | + } |
| 137 | + |
| 138 | + private static void addToolMessages( |
| 139 | + final List<OpenAiMessage> result, final ToolResponseMessage message) { |
| 140 | + for (final ToolResponseMessage.ToolResponse response : message.getResponses()) { |
| 141 | + result.add(OpenAiMessage.tool(response.responseData(), response.id())); |
| 142 | + } |
| 143 | + } |
| 144 | + |
| 145 | + @Nonnull |
| 146 | + private static List<Generation> toGenerations( |
| 147 | + @Nonnull final OpenAiChatCompletionResponse result) { |
| 148 | + return result.getOriginalResponse().getChoices().stream() |
| 149 | + .map(OpenAiChatModel::toGeneration) |
| 150 | + .toList(); |
| 151 | + } |
| 152 | + |
| 153 | + @Nonnull |
| 154 | + private static Generation toGeneration( |
| 155 | + @Nonnull final CreateChatCompletionResponseChoicesInner choice) { |
| 156 | + val metadata = |
| 157 | + ChatGenerationMetadata.builder().finishReason(choice.getFinishReason().getValue()); |
| 158 | + metadata.metadata("index", choice.getIndex()); |
| 159 | + if (choice.getLogprobs() != null && !choice.getLogprobs().getContent().isEmpty()) { |
| 160 | + metadata.metadata("logprobs", choice.getLogprobs().getContent()); |
| 161 | + } |
| 162 | + val message = choice.getMessage(); |
| 163 | + val calls = new ArrayList<ToolCall>(); |
| 164 | + if (message.getToolCalls() != null) { |
| 165 | + for (final ChatCompletionMessageToolCall c : message.getToolCalls()) { |
| 166 | + val fnc = c.getFunction(); |
| 167 | + calls.add( |
| 168 | + new ToolCall(c.getId(), c.getType().getValue(), fnc.getName(), fnc.getArguments())); |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + val assistantMessage = new AssistantMessage(message.getContent(), Map.of(), calls); |
| 173 | + return new Generation(assistantMessage, metadata.build()); |
| 174 | + } |
| 175 | + |
| 176 | + /** |
| 177 | + * Adds options to the request. |
| 178 | + * |
| 179 | + * @param request the request to modify |
| 180 | + * @param options the options to extract |
| 181 | + * @return the modified request with options applied |
| 182 | + */ |
| 183 | + @Nonnull |
| 184 | + protected static OpenAiChatCompletionRequest extractOptions( |
| 185 | + @Nonnull OpenAiChatCompletionRequest request, @Nonnull final ChatOptions options) { |
| 186 | + request = request.withStop(options.getStopSequences()).withMaxTokens(options.getMaxTokens()); |
| 187 | + if (options.getTemperature() != null) { |
| 188 | + request = request.withTemperature(BigDecimal.valueOf(options.getTemperature())); |
| 189 | + } |
| 190 | + if (options.getTopP() != null) { |
| 191 | + request = request.withTopP(BigDecimal.valueOf(options.getTopP())); |
| 192 | + } |
| 193 | + if (options.getPresencePenalty() != null) { |
| 194 | + request = request.withPresencePenalty(BigDecimal.valueOf(options.getPresencePenalty())); |
| 195 | + } |
| 196 | + if (options.getFrequencyPenalty() != null) { |
| 197 | + request = request.withFrequencyPenalty(BigDecimal.valueOf(options.getFrequencyPenalty())); |
| 198 | + } |
| 199 | + return request; |
| 200 | + } |
| 201 | + |
| 202 | + private static List<ChatCompletionTool> extractTools(final ToolCallingChatOptions options) { |
| 203 | + val tools = new ArrayList<ChatCompletionTool>(); |
| 204 | + for (val toolCallback : options.getToolCallbacks()) { |
| 205 | + val toolDefinition = toolCallback.getToolDefinition(); |
| 206 | + try { |
| 207 | + final Map<String, Object> params = |
| 208 | + new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {}); |
| 209 | + val toolType = ChatCompletionTool.TypeEnum.FUNCTION; |
| 210 | + val toolFunction = |
| 211 | + new FunctionObject() |
| 212 | + .name(toolDefinition.name()) |
| 213 | + .description(toolDefinition.description()) |
| 214 | + .parameters(params); |
| 215 | + val tool = new ChatCompletionTool().type(toolType).function(toolFunction); |
| 216 | + tools.add(tool); |
| 217 | + } catch (JsonProcessingException e) { |
| 218 | + log.warn("Failed to add tool to the chat request: {}", e.getMessage()); |
| 219 | + } |
| 220 | + } |
| 221 | + return tools; |
| 222 | + } |
| 223 | +} |
0 commit comments