Skip to content

Commit 3667b1e

Browse files
committed
Reverted the intermediate result class; OpenAiTool#execute will directly return the tool message list
1 parent ce3ba95 commit 3667b1e

File tree

3 files changed

+34
-53
lines changed

3 files changed

+34
-53
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiTool.java

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import lombok.AccessLevel;
2929
import lombok.AllArgsConstructor;
3030
import lombok.Getter;
31-
import lombok.RequiredArgsConstructor;
3231
import lombok.Setter;
3332
import lombok.Value;
3433
import lombok.With;
@@ -163,14 +162,34 @@ private static SchemaGenerator createSchemaGenerator() {
163162
*
164163
* @param tools the list of tools to execute
165164
* @param msg the assistant message containing a list of tool calls with arguments
166-
* @return a result object that contains the list of tool messages with the results
165+
* @return The list of tool messages with the results.
167166
*/
168167
@Beta
169168
@Nonnull
170-
public static Execution execute(
169+
public static List<OpenAiToolMessage> execute(
171170
@Nonnull final List<OpenAiTool> tools, @Nonnull final OpenAiAssistantMessage msg) {
172-
final var result = new LinkedHashMap<OpenAiFunctionCall, Object>();
171+
final var toolResults = executeInternal(tools, msg);
172+
final var result = new ArrayList<OpenAiToolMessage>();
173+
for (final var entry : toolResults.entrySet()) {
174+
final var functionCall = entry.getKey().getId();
175+
final var serializedValue = serializeObject(entry.getValue());
176+
result.add(OpenAiMessage.tool(serializedValue, functionCall));
177+
}
178+
return result;
179+
}
173180

181+
/**
182+
* Executes the given tool calls with the provided tools and returns the results as a list of
183+
* {@link OpenAiToolMessage} containing execution results encoded as JSON string.
184+
*
185+
* @param tools the list of tools to execute
186+
* @param msg the assistant message containing a list of tool calls with arguments
187+
* @return a map that contains the function calls and their respective tool results.
188+
*/
189+
@Nonnull
190+
protected static Map<OpenAiFunctionCall, Object> executeInternal(
191+
@Nonnull final List<OpenAiTool> tools, @Nonnull final OpenAiAssistantMessage msg) {
192+
final var result = new LinkedHashMap<OpenAiFunctionCall, Object>();
174193
final var toolMap = tools.stream().collect(Collectors.toMap(OpenAiTool::getName, identity()));
175194
for (final OpenAiToolCall toolCall : msg.toolCalls()) {
176195
if (toolCall instanceof OpenAiFunctionCall functionCall) {
@@ -183,7 +202,7 @@ public static Execution execute(
183202
result.put(functionCall, toolResult);
184203
}
185204
}
186-
return new Execution(result);
205+
return result;
187206
}
188207

189208
@Nonnull
@@ -202,31 +221,4 @@ private static String serializeObject(@Nonnull final Object obj) throws IllegalA
202221
throw new IllegalArgumentException("Failed to serialize object to JSON", e);
203222
}
204223
}
205-
206-
/**
207-
* Represents the result of executing a tool call, containing the results of the function calls.
208-
*/
209-
@RequiredArgsConstructor
210-
@Beta
211-
public static class Execution {
212-
@Getter @Beta @Nonnull private final Map<OpenAiFunctionCall, Object> results;
213-
214-
/**
215-
* Creates a new list of serialized OpenAI tool messages.
216-
*
217-
* @return the list of serialized OpenAI tool messages.
218-
* @throws IllegalArgumentException if the tool results cannot be serialized to JSON
219-
*/
220-
@Beta
221-
@Nonnull
222-
public List<OpenAiToolMessage> getMessages() {
223-
final var result = new ArrayList<OpenAiToolMessage>();
224-
for (final var entry : getResults().entrySet()) {
225-
final var functionCall = entry.getKey().getId();
226-
final var serializedValue = serializeObject(entry.getValue());
227-
result.add(OpenAiMessage.tool(serializedValue, functionCall));
228-
}
229-
return result;
230-
}
231-
}
232224
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolTest.java

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import java.util.Collections;
77
import java.util.List;
8-
import java.util.Map;
98
import java.util.function.Function;
109
import lombok.EqualsAndHashCode;
1110
import org.junit.jupiter.api.Test;
@@ -64,17 +63,11 @@ void executeToolsValid() {
6463
.withArgument(Dummy.Request.class)
6564
.withName("functionA");
6665
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A));
67-
final var execution = OpenAiTool.execute(List.of(toolA), assistMsg);
66+
final var toolMsgs = OpenAiTool.execute(List.of(toolA), assistMsg);
6867

69-
final var results = execution.getResults();
70-
assertThat(results)
71-
.hasSize(1)
72-
.containsExactly(Map.entry(FUNCTION_CALL_A, new Dummy.Response("value")));
73-
74-
final var toolMsg = execution.getMessages();
75-
assertThat(toolMsg).hasSize(1);
76-
assertThat(toolMsg.get(0).toolCallId()).isEqualTo("1");
77-
assertThat(((OpenAiTextItem) toolMsg.get(0).content().items().get(0)).text())
68+
assertThat(toolMsgs).hasSize(1);
69+
assertThat(toolMsgs.get(0).toolCallId()).isEqualTo("1");
70+
assertThat(((OpenAiTextItem) toolMsgs.get(0).content().items().get(0)).text())
7871
.isEqualTo("{\"toolMsg\":\"value\"}");
7972
}
8073

@@ -85,9 +78,8 @@ void executeToolsNoMatchingCall() {
8578
.withArgument(Dummy.Request.class)
8679
.withName("functionA");
8780
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_B));
88-
final var executions = OpenAiTool.execute(List.of(toolA), assistMsg);
89-
assertThat(executions.getResults()).isEmpty();
90-
assertThat(executions.getMessages()).isEmpty();
81+
final var toolMsgs = OpenAiTool.execute(List.of(toolA), assistMsg);
82+
assertThat(toolMsgs).isEmpty();
9183
}
9284

9385
@Test
@@ -106,12 +98,8 @@ class NonSerializableResponse {
10698
final var toolA =
10799
OpenAiTool.forFunction(badF).withArgument(Dummy.Request.class).withName("functionA");
108100
final var assistMsg = new OpenAiAssistantMessage(EMPTY_MSG_CONTENT, List.of(FUNCTION_CALL_A));
109-
final var executions = OpenAiTool.execute(List.of(toolA), assistMsg);
110-
111-
assertThat(executions.getResults())
112-
.containsExactly(Map.entry(FUNCTION_CALL_A, new NonSerializableResponse("value")));
113101

114-
assertThatThrownBy(executions::getMessages)
102+
assertThatThrownBy(() -> OpenAiTool.execute(List.of(toolA), assistMsg))
115103
.isInstanceOf(IllegalArgumentException.class)
116104
.hasMessageContaining("Failed to serialize object to JSON");
117105
}

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1616
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1717
import com.sap.ai.sdk.foundationmodels.openai.OpenAiTool;
18+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolMessage;
1819
import java.util.ArrayList;
1920
import java.util.List;
2021
import java.util.stream.Stream;
@@ -112,11 +113,11 @@ public OpenAiChatCompletionResponse chatCompletionToolExecution(
112113

113114
// 3. Execute the tool call for given tools
114115
final OpenAiAssistantMessage assistantMessage = response.getMessage();
115-
final var toolResults = OpenAiTool.execute(tools, assistantMessage);
116+
final List<OpenAiToolMessage> toolMessages = OpenAiTool.execute(tools, assistantMessage);
116117

117118
// 4. Return the results so that the model can incorporate them into the final response.
118119
messages.add(assistantMessage);
119-
messages.addAll(toolResults.getMessages());
120+
messages.addAll(toolMessages);
120121

121122
return client.chatCompletion(request.withMessages(messages));
122123
}

0 commit comments

Comments
 (0)