Skip to content

Commit bcb42c8

Browse files
committed
With purely generics
1 parent 36b0189 commit bcb42c8

File tree

9 files changed

+149
-155
lines changed

9 files changed

+149
-155
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.google.common.annotations.Beta;
44
import com.google.common.collect.Lists;
55
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
67
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
78
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
89
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat;
@@ -122,9 +123,7 @@ public class OpenAiChatCompletionRequest {
122123
@Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat;
123124

124125
/** List of tools that the model may invoke during the completion. */
125-
@Getter(value = AccessLevel.PACKAGE)
126-
@Nullable
127-
List<OpenAiFunctionTool<?, ?>> tools;
126+
@Nullable List<ChatCompletionTool> tools;
128127

129128
/** Option to control which tool is invoked by the model. */
130129
@With(AccessLevel.PRIVATE)
@@ -293,18 +292,7 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic
293292
*/
294293
@Nonnull
295294
public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List<OpenAiTool> tools) {
296-
return this.withTools(
297-
tools.stream()
298-
.map(
299-
tool -> {
300-
if (tool instanceof OpenAiFunctionTool) {
301-
return ((OpenAiFunctionTool<?, ?>) tool).createChatCompletionTool();
302-
} else {
303-
throw new IllegalArgumentException(
304-
"Unsupported tool type: " + tool.getClass().getName());
305-
}
306-
})
307-
.toList());
295+
return this.withTools(tools.stream().map(OpenAiTool::createChatCompletionTool).toList());
308296
}
309297

310298
/**
@@ -337,7 +325,7 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() {
337325
request.seed(this.seed);
338326
request.streamOptions(this.streamOptions);
339327
request.responseFormat(this.responseFormat);
340-
request.tools(this.tools.createChatCompletionTool());
328+
request.tools(this.tools);
341329
request.toolChoice(this.toolChoice);
342330
request.functionCall(null);
343331
request.functions(null);

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66

77
import com.fasterxml.jackson.core.type.TypeReference;
88
import com.google.common.annotations.Beta;
9-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
109
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
1110
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
1211
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
12+
import java.util.ArrayList;
1313
import java.util.List;
1414
import java.util.Objects;
1515
import javax.annotation.Nonnull;
16-
import javax.annotation.Nullable;
1716
import lombok.RequiredArgsConstructor;
1817
import lombok.Setter;
1918
import lombok.Value;
@@ -31,8 +30,6 @@ public class OpenAiChatCompletionResponse {
3130
/** The original response from the OpenAI API. */
3231
@Nonnull CreateChatCompletionResponse originalResponse;
3332

34-
@Nullable List<ChatCompletionTool> functions;
35-
3633
/**
3734
* Gets the token usage from the original response.
3835
*
@@ -103,29 +100,19 @@ public OpenAiAssistantMessage getMessage() {
103100
}
104101

105102
public <T, R> List<OpenAiToolMessage> executeTools(List<OpenAiTool> tools) {
106-
return getMessage().toolCalls().stream()
107-
.filter(toolCall -> toolCall instanceof OpenAiFunctionCall)
108-
.map(toolCall -> (OpenAiFunctionCall) toolCall)
109-
.map(
110-
functionCall -> {
111-
OpenAiFunctionTool<T, R> request = findFunction(tools, functionCall.getName());
112-
T arguments = functionCall.parseArguments(new TypeReference<T>() {});
113-
R response = request.call(arguments);
114-
return OpenAiMessage.tool(response, functionCall.getId());
115-
})
116-
.toList();
117-
}
103+
var toolMessages = new ArrayList<OpenAiToolMessage>();
118104

119-
@Nullable
120-
private <T, R> OpenAiFunctionTool<T, R> findFunction(List<OpenAiTool> tools, String name) {
121-
if (functions == null) {
122-
return null;
123-
}
124-
for (OpenAiTool tool : tools) {
125-
if (tool instanceof OpenAiFunctionTool<?, ?> function && function.getName().equals(name)) {
126-
return (OpenAiFunctionTool<T, R>) function;
105+
for (var toolCall : getMessage().toolCalls()) {
106+
if (toolCall instanceof OpenAiFunctionCall functionCall) {
107+
for (OpenAiTool<T, R> tool : tools) {
108+
if (tool.getName().equals(functionCall.getName())) {
109+
T arguments = functionCall.getArgumentsAsObject(new TypeReference<T>() {});
110+
R response = tool.call(arguments);
111+
toolMessages.add(OpenAiMessage.tool(response.toString(), functionCall.getId()));
112+
}
113+
}
127114
}
128115
}
129-
return null;
116+
return toolMessages;
130117
}
131118
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ public OpenAiChatCompletionResponse chatCompletion(
159159
@Nonnull final OpenAiChatCompletionRequest request) throws OpenAiClientException {
160160
warnIfUnsupportedUsage();
161161
return new OpenAiChatCompletionResponse(
162-
chatCompletion(request.createCreateChatCompletionRequest()), request.getTools());
162+
chatCompletion(request.createCreateChatCompletionRequest()));
163163
}
164164

165165
/**

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionCall.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.fasterxml.jackson.core.JsonProcessingException;
66
import com.fasterxml.jackson.core.type.TypeReference;
77
import com.google.common.annotations.Beta;
8+
import java.util.Map;
89
import javax.annotation.Nonnull;
910
import lombok.AllArgsConstructor;
1011
import lombok.Value;
@@ -28,20 +29,25 @@ public class OpenAiFunctionCall implements OpenAiToolCall {
2829
@Nonnull String arguments;
2930

3031
/**
31-
* Parses the arguments, encoded as a JSON string, into an object of type expected by a function
32-
* tool.
32+
* Parses the arguments, encoded as a JSON string, into a {@code Map<String, Object>}.
3333
*
34-
* @param request the type of the class
35-
* @return the parsed arguments as an object
34+
* @return a map of the arguments
3635
* @throws IllegalArgumentException if parsing fails
3736
* @since 1.7.0
3837
*/
3938
@Nonnull
40-
<T> T parseArguments(@Nonnull final TypeReference<T> request) throws IllegalArgumentException {
39+
public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
40+
return getArgumentsAsObject(new TypeReference<>() {});
41+
}
42+
43+
@Nonnull
44+
<T> T getArgumentsAsObject(@Nonnull final TypeReference<T> typeReference)
45+
throws IllegalArgumentException {
4146
try {
42-
return getOpenAiObjectMapper().readValue(arguments, request);
47+
return getOpenAiObjectMapper().readValue(arguments, typeReference);
4348
} catch (JsonProcessingException e) {
44-
throw new IllegalArgumentException("Failed to parse JSON string to class " + request, e);
49+
throw new IllegalArgumentException(
50+
"Failed to parse JSON string to class " + typeReference, e);
4551
}
4652
}
4753
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiFunctionTool.java

Lines changed: 0 additions & 89 deletions
This file was deleted.
Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,89 @@
11
package com.sap.ai.sdk.foundationmodels.openai;
22

3+
import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION;
4+
5+
import com.fasterxml.jackson.core.type.TypeReference;
6+
import com.fasterxml.jackson.databind.JsonMappingException;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
9+
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
10+
import com.google.common.annotations.Beta;
11+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
12+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
13+
import java.util.Map;
14+
import java.util.function.Function;
15+
import javax.annotation.Nonnull;
16+
import javax.annotation.Nullable;
17+
import lombok.AccessLevel;
18+
import lombok.AllArgsConstructor;
19+
import lombok.Getter;
20+
import lombok.Value;
21+
import lombok.With;
22+
323
/**
4-
* Represents a tool that can be integrated into an OpenAI Chat Completion request.
24+
* Represents an OpenAI function tool that can be used to define a function call in an OpenAI Chat
25+
* Completion request. This tool generates a JSON schema based on the provided class representing
26+
* the function's request structure.
527
*
28+
* @see <a href="https://platform.openai.com/docs/guides/gpt/function-calling"/>OpenAI Function
629
* @since 1.7.0
730
*/
8-
public sealed interface OpenAiTool permits OpenAiFunctionTool {}
31+
@Beta
32+
@Value
33+
@With
34+
@Getter(AccessLevel.PACKAGE)
35+
@AllArgsConstructor(access = AccessLevel.PRIVATE)
36+
public class OpenAiTool<T, R> {
37+
38+
/** The name of the function. */
39+
@Nonnull String name;
40+
41+
/** The model class for function request. */
42+
@Nonnull Function<T, R> function;
43+
44+
/** An optional description of the function. */
45+
@Nullable String description;
46+
47+
/** An optional flag indicating whether the function parameters should be treated strictly. */
48+
@Nullable Boolean strict;
49+
50+
/**
51+
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
52+
* captures the request to the function.
53+
*
54+
* @param name the name of the function
55+
* @param function the model class for function request
56+
*/
57+
public OpenAiTool(@Nonnull final String name, @Nonnull final Function<T, R> function) {
58+
this(name, function, null, null);
59+
}
60+
61+
@Nonnull
62+
R call(@Nonnull final T request) {
63+
return function.apply(request);
64+
}
65+
66+
ChatCompletionTool createChatCompletionTool() {
67+
final var objectMapper = new ObjectMapper();
68+
JsonSchema schema = null;
69+
try {
70+
schema =
71+
new JsonSchemaGenerator(objectMapper)
72+
.generateSchema(new TypeReference<T>() {}.getClass());
73+
} catch (JsonMappingException e) {
74+
throw new IllegalArgumentException("Could not generate schema for " + name, e);
75+
}
76+
77+
schema.setId(null);
78+
final var schemaMap =
79+
objectMapper.convertValue(schema, new TypeReference<Map<String, Object>>() {});
80+
81+
final var function =
82+
new FunctionObject()
83+
.name(getName())
84+
.description(getDescription())
85+
.parameters(schemaMap)
86+
.strict(getStrict());
87+
return new ChatCompletionTool().type(FUNCTION).function(function);
88+
}
89+
}

0 commit comments

Comments
 (0)