Skip to content

Commit f9a719e

Browse files
Added more options and metadata
1 parent 9b2b5bf commit f9a719e

File tree

3 files changed

+106
-62
lines changed

3 files changed

+106
-62
lines changed

foundation-models/openai/pom.xml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
</scm>
3939
<properties>
4040
<project.rootdir>${project.basedir}/../../</project.rootdir>
41-
<coverage.complexity>83%</coverage.complexity>
42-
<coverage.line>92%</coverage.line>
43-
<coverage.instruction>90%</coverage.instruction>
44-
<coverage.branch>81%</coverage.branch>
41+
<coverage.complexity>81%</coverage.complexity>
42+
<coverage.line>91%</coverage.line>
43+
<coverage.instruction>89%</coverage.instruction>
44+
<coverage.branch>79%</coverage.branch>
4545
<coverage.method>90%</coverage.method>
4646
<coverage.class>92%</coverage.class>
4747
</properties>
Lines changed: 96 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package com.sap.ai.sdk.foundationmodels.openai.spring;
22

3-
import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled;
4-
53
import com.fasterxml.jackson.core.JsonProcessingException;
64
import com.fasterxml.jackson.core.type.TypeReference;
75
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -12,29 +10,35 @@
1210
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1311
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall;
1412
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall;
15-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessage;
1613
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
14+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
1715
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
1816
import io.vavr.control.Option;
19-
import java.util.ArrayList;
20-
import java.util.List;
21-
import java.util.Map;
22-
import java.util.function.Function;
23-
import javax.annotation.Nonnull;
2417
import lombok.RequiredArgsConstructor;
2518
import lombok.val;
2619
import org.springframework.ai.chat.messages.AssistantMessage;
2720
import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
2821
import org.springframework.ai.chat.messages.Message;
2922
import org.springframework.ai.chat.messages.ToolResponseMessage;
23+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3024
import org.springframework.ai.chat.model.ChatModel;
3125
import org.springframework.ai.chat.model.ChatResponse;
3226
import org.springframework.ai.chat.model.Generation;
27+
import org.springframework.ai.chat.prompt.ChatOptions;
3328
import org.springframework.ai.chat.prompt.Prompt;
34-
import org.springframework.ai.model.tool.DefaultToolCallingChatOptions;
3529
import org.springframework.ai.model.tool.DefaultToolCallingManager;
30+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3631
import reactor.core.publisher.Flux;
3732

33+
import javax.annotation.Nonnull;
34+
import java.math.BigDecimal;
35+
import java.util.ArrayList;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.function.Function;
39+
40+
import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled;
41+
3842
/**
3943
* OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions.
4044
*/
@@ -50,34 +54,40 @@ public class OpenAiChatModel implements ChatModel {
5054
@Override
5155
@Nonnull
5256
public ChatResponse call(@Nonnull final Prompt prompt) {
53-
val openAiRequest = toOpenAiRequest(prompt);
54-
var request = new OpenAiChatCompletionRequest(openAiRequest);
57+
val options = prompt.getOptions();
58+
var request = new OpenAiChatCompletionRequest(extractMessages(prompt));
5559

56-
if ((prompt.getOptions() instanceof DefaultToolCallingChatOptions options)) {
57-
request = request.withTools(extractTools(options));
60+
if (options != null) {
61+
request = extractOptions(request, options);
62+
}
63+
if ((options instanceof ToolCallingChatOptions toolOptions)) {
64+
request = request.withTools(extractTools(toolOptions));
5865
}
5966

6067
val result = client.chatCompletion(request);
6168
val response = new ChatResponse(toGenerations(result));
6269

63-
if (prompt.getOptions() != null
64-
&& isInternalToolExecutionEnabled(prompt.getOptions())
65-
&& response.hasToolCalls()) {
70+
if (options != null && isInternalToolExecutionEnabled(options) && response.hasToolCalls()) {
6671
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response);
6772
// Send the tool execution result back to the model.
68-
return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
73+
return call(new Prompt(toolExecutionResult.conversationHistory(), options));
6974
}
7075
return response;
7176
}
7277

7378
@Override
7479
@Nonnull
7580
public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
76-
val openAiRequest = toOpenAiRequest(prompt);
77-
var request = new OpenAiChatCompletionRequest(openAiRequest);
78-
if ((prompt.getOptions() instanceof DefaultToolCallingChatOptions options)) {
79-
request = request.withTools(extractTools(options));
81+
val options = prompt.getOptions();
82+
var request = new OpenAiChatCompletionRequest(extractMessages(prompt));
83+
84+
if (options != null) {
85+
request = extractOptions(request, options);
86+
}
87+
if ((options instanceof ToolCallingChatOptions toolOptions)) {
88+
request = request.withTools(extractTools(toolOptions));
8089
}
90+
8191
val stream = client.streamChatCompletionDeltas(request);
8292
final Flux<OpenAiChatCompletionDelta> flux =
8393
Flux.generate(
@@ -90,37 +100,16 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
90100
}
91101
return iterator;
92102
});
93-
return flux.map(OpenAiChatModel::toChatResponse);
103+
return flux.map(
104+
delta -> {
105+
val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of());
106+
val metadata =
107+
ChatGenerationMetadata.builder().finishReason(delta.getFinishReason()).build();
108+
return new ChatResponse(List.of(new Generation(assistantMessage, metadata)));
109+
});
94110
}
95111

96-
private List<ChatCompletionTool> extractTools(final DefaultToolCallingChatOptions options) {
97-
val tools = new ArrayList<ChatCompletionTool>();
98-
for (val toolCallback : options.getToolCallbacks()) {
99-
val toolDefinition = toolCallback.getToolDefinition();
100-
try {
101-
final Map<String, Object> params =
102-
new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {});
103-
val tool =
104-
new ChatCompletionTool()
105-
.type(ChatCompletionTool.TypeEnum.FUNCTION)
106-
.function(
107-
new FunctionObject()
108-
.name(toolDefinition.name())
109-
.description(toolDefinition.description())
110-
.parameters(params));
111-
tools.add(tool);
112-
} catch (JsonProcessingException ignored) {
113-
}
114-
}
115-
return tools;
116-
}
117-
118-
private static ChatResponse toChatResponse(final OpenAiChatCompletionDelta delta) {
119-
val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of());
120-
return new ChatResponse(List.of(new Generation(assistantMessage)));
121-
}
122-
123-
private List<OpenAiMessage> toOpenAiRequest(final Prompt prompt) {
112+
private List<OpenAiMessage> extractMessages(final Prompt prompt) {
124113
final List<OpenAiMessage> result = new ArrayList<>();
125114
for (final Message message : prompt.getInstructions()) {
126115
switch (message.getMessageType()) {
@@ -153,24 +142,73 @@ private static void addToolMessages(
153142
}
154143

155144
@Nonnull
156-
static List<Generation> toGenerations(@Nonnull final OpenAiChatCompletionResponse result) {
145+
private static List<Generation> toGenerations(
146+
@Nonnull final OpenAiChatCompletionResponse result) {
157147
return result.getOriginalResponse().getChoices().stream()
158-
.map(message -> toGeneration(message.getMessage()))
148+
.map(OpenAiChatModel::toGeneration)
159149
.toList();
160150
}
161151

162152
@Nonnull
163-
static Generation toGeneration(@Nonnull final ChatCompletionResponseMessage choice) {
164-
// no metadata for now
153+
private static Generation toGeneration(
154+
@Nonnull final CreateChatCompletionResponseChoicesInner choice) {
155+
val metadata =
156+
ChatGenerationMetadata.builder().finishReason(choice.getFinishReason().getValue());
157+
metadata.metadata("index", choice.getIndex());
158+
if (choice.getLogprobs() != null && !choice.getLogprobs().getContent().isEmpty()) {
159+
metadata.metadata("logprobs", choice.getLogprobs().getContent());
160+
}
161+
val message = choice.getMessage();
165162
val calls = new ArrayList<ToolCall>();
166-
if (choice.getToolCalls() != null) {
167-
for (final ChatCompletionMessageToolCall c : choice.getToolCalls()) {
163+
if (message.getToolCalls() != null) {
164+
for (final ChatCompletionMessageToolCall c : message.getToolCalls()) {
168165
val fnc = c.getFunction();
169166
calls.add(
170167
new ToolCall(c.getId(), c.getType().getValue(), fnc.getName(), fnc.getArguments()));
171168
}
172169
}
173-
val message = new AssistantMessage(choice.getContent(), Map.of(), calls);
174-
return new Generation(message);
170+
171+
val assistantMessage = new AssistantMessage(message.getContent(), Map.of(), calls);
172+
return new Generation(assistantMessage, metadata.build());
173+
}
174+
175+
private OpenAiChatCompletionRequest extractOptions(
176+
@Nonnull OpenAiChatCompletionRequest request, @Nonnull final ChatOptions options) {
177+
request = request.withStop(options.getStopSequences()).withMaxTokens(options.getMaxTokens());
178+
if (options.getTemperature() != null) {
179+
request = request.withTemperature(BigDecimal.valueOf(options.getTemperature()));
180+
}
181+
if (options.getTopP() != null) {
182+
request = request.withTopP(BigDecimal.valueOf(options.getTopP()));
183+
}
184+
if (options.getPresencePenalty() != null) {
185+
request = request.withPresencePenalty(BigDecimal.valueOf(options.getPresencePenalty()));
186+
}
187+
if (options.getFrequencyPenalty() != null) {
188+
request = request.withFrequencyPenalty(BigDecimal.valueOf(options.getFrequencyPenalty()));
189+
}
190+
return request;
191+
}
192+
193+
private List<ChatCompletionTool> extractTools(final ToolCallingChatOptions options) {
194+
val tools = new ArrayList<ChatCompletionTool>();
195+
for (val toolCallback : options.getToolCallbacks()) {
196+
val toolDefinition = toolCallback.getToolDefinition();
197+
try {
198+
final Map<String, Object> params =
199+
new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {});
200+
val tool =
201+
new ChatCompletionTool()
202+
.type(ChatCompletionTool.TypeEnum.FUNCTION)
203+
.function(
204+
new FunctionObject()
205+
.name(toolDefinition.name())
206+
.description(toolDefinition.description())
207+
.parameters(params));
208+
tools.add(tool);
209+
} catch (JsonProcessingException ignored) {
210+
}
211+
}
212+
return tools;
175213
}
176214
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModelTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ void testStreamCompletion() throws IOException {
114114
assertThat(deltaList.get(3).getResult().getOutput().getText()).isEqualTo("!");
115115
assertThat(deltaList.get(4).getResult().getOutput().getText()).isEqualTo("");
116116

117+
assertThat(deltaList.get(0).getResult().getMetadata().getFinishReason()).isEqualTo(null);
118+
assertThat(deltaList.get(1).getResult().getMetadata().getFinishReason()).isEqualTo(null);
119+
assertThat(deltaList.get(2).getResult().getMetadata().getFinishReason()).isEqualTo(null);
120+
assertThat(deltaList.get(3).getResult().getMetadata().getFinishReason()).isEqualTo(null);
121+
assertThat(deltaList.get(4).getResult().getMetadata().getFinishReason()).isEqualTo("stop");
122+
117123
Mockito.verify(inputStream, times(1)).close();
118124
}
119125
}

0 commit comments

Comments
 (0)