Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

### ✨ New Functionality

-
- [OpenAI] [Add convenience for tool definition and parsing function calls](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#executing-tool-calls)

### 📈 Improvements

Expand Down
4 changes: 4 additions & 0 deletions foundation-models/openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-jsonSchema</artifactId>
</dependency>
<dependency>
<groupId>io.vavr</groupId>
<artifactId>vavr</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,30 @@ public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoic
return this.withToolChoice(choice.toolChoice);
}

/**
* Sets the tools to be used in the request with convenience class {@code OpenAiTool}.
*
* @param tools the list of tools to be used
* @return a new OpenAiChatCompletionRequest instance with the specified tools
* @throws IllegalArgumentException if the tool type is not supported
* @since 1.7.0
*/
@Nonnull
public OpenAiChatCompletionRequest withOpenAiTools(@Nonnull final List<OpenAiTool> tools) {
return this.withTools(
tools.stream()
.map(
tool -> {
if (tool instanceof OpenAiFunctionTool) {
return ((OpenAiFunctionTool) tool).createChatCompletionTool();
} else {
throw new IllegalArgumentException(
"Unsupported tool type: " + tool.getClass().getName());
}
})
.toList());
}

/**
* Converts the request to a generated model class CreateChatCompletionRequest.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.sap.ai.sdk.foundationmodels.openai.OpenAiUtils.getOpenAiObjectMapper;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.common.annotations.Beta;
import java.lang.reflect.Type;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.AllArgsConstructor;
import lombok.Value;
Expand All @@ -22,4 +28,50 @@ public class OpenAiFunctionCall implements OpenAiToolCall {

/** The arguments for the function call, encoded as a JSON string. */
@Nonnull String arguments;

/**
* Parses the arguments, encoded as a JSON string, into a {@code Map<String, Object>}.
*
* @return a map of the arguments
* @throws IllegalArgumentException if parsing fails
* @since 1.7.0
*/
@Nonnull
public Map<String, Object> getArgumentsAsMap() throws IllegalArgumentException {
return parseArguments(new TypeReference<>() {});
}

/**
* Parses the arguments, encoded as a JSON string, into an object of type expected by a function
* tool.
*
* @param tool the function tool the arguments are for
* @return the parsed arguments as an object
* @param <T> the type of object accepted by the function tool
* @throws IllegalArgumentException if parsing arguments fails
* @since 1.7.0
*/
@Nonnull
public <T> T getArgumentsAsObject(@Nonnull final OpenAiFunctionTool tool)
throws IllegalArgumentException {
final var typeRef =
new TypeReference<T>() {
@Override
public Type getType() {
return tool.getRequestModel();
}
};
return parseArguments(typeRef);
}

@Nonnull
private <T> T parseArguments(@Nonnull final TypeReference<T> typeReference)
throws IllegalArgumentException {
try {
return getOpenAiObjectMapper().readValue(getArguments(), typeReference);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(
"Failed to parse JSON string to class " + typeReference.getType(), e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum.FUNCTION;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Value;
import lombok.With;

/**
* Represents an OpenAI function tool that can be used to define a function call in an OpenAI Chat
* Completion request. This tool generates a JSON schema based on the provided class representing
* the function's request structure.
*
* @see <a href="https://platform.openai.com/docs/guides/gpt/function-calling"/>OpenAI Function
* @since 1.7.0
*/
@Beta
@Value
@With
@Getter(AccessLevel.PACKAGE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class OpenAiFunctionTool implements OpenAiTool {

/** The name of the function. */
@Nonnull String name;

/** The model class for function request. */
@Nonnull Class<?> requestModel;

/** An optional description of the function. */
@Nullable String description;

/** An optional flag indicating whether the function parameters should be treated strictly. */
@Nullable Boolean strict;

/**
* Constructs an {@code OpenAiFunctionTool} with the specified name and a model class that
* captures the request to the function.
*
* @param name the name of the function
* @param requestModel the model class for the function request
* @param <T> the type of the request model
*/
public <T> OpenAiFunctionTool(@Nonnull final String name, @Nonnull final Class<T> requestModel) {
this(name, requestModel, null, null);
}

ChatCompletionTool createChatCompletionTool() {
final var objectMapper = new ObjectMapper();
JsonSchema schema = null;
try {
schema = new JsonSchemaGenerator(objectMapper).generateSchema(requestModel);
} catch (JsonMappingException e) {
throw new IllegalArgumentException(
"Could not generate schema for " + requestModel.getTypeName(), e);
}

schema.setId(null);
final var schemaMap =
objectMapper.convertValue(schema, new TypeReference<Map<String, Object>>() {});

final var function =
new FunctionObject()
.name(getName())
.description(getDescription())
.parameters(schemaMap)
.strict(getStrict());
return new ChatCompletionTool().type(FUNCTION).function(function);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.sap.ai.sdk.foundationmodels.openai;

/**
* Represents a tool that can be integrated into an OpenAI Chat Completion request.
*
* @since 1.7.0
*/
public sealed interface OpenAiTool permits OpenAiFunctionTool {}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;

class OpenAiChatCompletionRequestTest {
Expand Down Expand Up @@ -114,4 +116,45 @@ void messageListExternallyUnmodifiable() {
.as("Modifying the original list should not affect the messages in the request object.")
.hasSize(1);
}

@Test
void withOpenAiTools() {
record DummyRequest(String param1, int param2) {}

var request =
new OpenAiChatCompletionRequest(OpenAiMessage.user("Hello, world"))
.withOpenAiTools(
List.of(
new OpenAiFunctionTool("toolA", DummyRequest.class)
.withDescription("descA")
.withStrict(true),
new OpenAiFunctionTool("toolB", String.class)
.withDescription("descB")
.withStrict(false)));

var lowLevelRequest = request.createCreateChatCompletionRequest();
assertThat(lowLevelRequest.getTools()).hasSize(2);

var toolA = lowLevelRequest.getTools().get(0);
assertThat(toolA).isInstanceOf(ChatCompletionTool.class);
assertThat(toolA.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION);
assertThat(toolA.getFunction().getName()).isEqualTo("toolA");
assertThat(toolA.getFunction().getDescription()).isEqualTo("descA");
assertThat(toolA.getFunction().isStrict()).isTrue();
assertThat(toolA.getFunction().getParameters())
.isEqualTo(
Map.of(
"properties",
Map.of("param1", Map.of("type", "string"), "param2", Map.of("type", "integer")),
"type",
"object"));

var toolB = lowLevelRequest.getTools().get(1);
assertThat(toolB).isInstanceOf(ChatCompletionTool.class);
assertThat(toolB.getType()).isEqualTo(ChatCompletionTool.TypeEnum.FUNCTION);
assertThat(toolB.getFunction().getName()).isEqualTo("toolB");
assertThat(toolB.getFunction().getDescription()).isEqualTo("descB");
assertThat(toolB.getFunction().isStrict()).isFalse();
assertThat(toolB.getFunction().getParameters()).isEqualTo(Map.of("type", "string"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.sap.ai.sdk.foundationmodels.openai;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.jupiter.api.Test;

class OpenAiToolCallTest {
private static final OpenAiFunctionCall VALID_FUNCTION_CALL =
new OpenAiFunctionCall("1", "functionName", "{\"key\":\"value\"}");
private static final OpenAiFunctionCall INVALID_FUNCTION_CALL =
new OpenAiFunctionCall("1", "functionName", "{invalid-json}");

private static final OpenAiFunctionTool FUNCTION_TOOL =
new OpenAiFunctionTool("functionName", DummyRequest.class);

record DummyRequest(String key) {}

@Test
void getArgumentsAsMapParsesValidJson() {
var result = VALID_FUNCTION_CALL.getArgumentsAsMap();
assertThat(result).containsEntry("key", "value");
}

@Test
void getArgumentsAsMapThrowsOnInvalidJson() {
assertThatThrownBy(INVALID_FUNCTION_CALL::getArgumentsAsMap)
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Failed to parse JSON string");
}

@Test
void getArgumentsAsObjectParsesValidJson() {
var result = (DummyRequest) VALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL);
assertThat(result).isInstanceOf(DummyRequest.class);
assertThat(result.key()).isEqualTo("value");
}

@Test
void getArgumentsAsObjectThrowsOnInvalidJson() {
assertThatThrownBy(() -> INVALID_FUNCTION_CALL.getArgumentsAsObject(FUNCTION_TOOL))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Failed to parse JSON string");
}
}
Loading