Skip to content

Commit 0c75937

Browse files
authored
feat: [OpenAI] Tool Choice Convenience (#330)
* Adding tool choice convenience methods in request class * refactor imports --------- Co-authored-by: Roshin Rajan Panackal <[email protected]>
1 parent 0643663 commit 0c75937

File tree

4 files changed

+100
-21
lines changed

4 files changed

+100
-21
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import com.google.common.annotations.Beta;
44
import com.google.common.collect.Lists;
5+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
57
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
68
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
79
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
@@ -126,7 +128,9 @@ public class OpenAiChatCompletionRequest {
126128
@Nullable List<ChatCompletionTool> tools;
127129

128130
/** Option to control which tool is invoked by the model. */
129-
@Nullable ChatCompletionToolChoiceOption toolChoice;
131+
@With(AccessLevel.PRIVATE)
132+
@Nullable
133+
ChatCompletionToolChoiceOption toolChoice;
130134

131135
/**
132136
* Creates an OpenAiChatCompletionPrompt with string as user message.
@@ -248,6 +252,51 @@ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs)
248252
this.toolChoice);
249253
}
250254

255+
/**
256+
* Only message generation will be performed without calling any tool.
257+
*
258+
* @return the current OpenAiChatCompletionRequest instance.
259+
*/
260+
@Nonnull
261+
public OpenAiChatCompletionRequest withToolChoiceNone() {
262+
return this.withToolChoice(ChatCompletionToolChoiceOption.create("none"));
263+
}
264+
265+
/**
266+
* The model may decide whether to call a (one or more) tool.
267+
*
268+
* @return the current OpenAiChatCompletionRequest instance.
269+
*/
270+
@Nonnull
271+
public OpenAiChatCompletionRequest withToolChoiceOptional() {
272+
return this.withToolChoice(ChatCompletionToolChoiceOption.create("auto"));
273+
}
274+
275+
/**
276+
* The model must call one or more tools as part of its processing.
277+
*
278+
* @return the current OpenAiChatCompletionRequest instance.
279+
*/
280+
@Nonnull
281+
public OpenAiChatCompletionRequest withToolChoiceRequired() {
282+
return this.withToolChoice(ChatCompletionToolChoiceOption.create("required"));
283+
}
284+
285+
/**
286+
* The model must call the function specified by {@code functionName}.
287+
*
288+
* @param functionName the name of the function that must be called.
289+
* @return the current OpenAiChatCompletionRequest instance.
290+
*/
291+
@Nonnull
292+
public OpenAiChatCompletionRequest withToolChoiceFunction(@Nonnull final String functionName) {
293+
return this.withToolChoice(
294+
ChatCompletionToolChoiceOption.create(
295+
new ChatCompletionNamedToolChoice()
296+
.type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
297+
.function(new ChatCompletionNamedToolChoiceFunction().name(functionName))));
298+
}
299+
251300
/**
252301
* Converts the request to a generated model class CreateChatCompletionRequest.
253302
*

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
66
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
78
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
89
import java.math.BigDecimal;
910
import java.util.List;
@@ -45,4 +46,51 @@ void createWithExistingRequest() {
4546
assertThat(lowlevelRequest.getSeed()).isEqualTo(123);
4647
assertThat(lowlevelRequest.getTemperature()).isEqualTo(BigDecimal.valueOf(0.5));
4748
}
49+
50+
@Test
51+
void withToolChoiceNone() {
52+
OpenAiChatCompletionRequest request =
53+
new OpenAiChatCompletionRequest("message").withToolChoiceNone();
54+
55+
var lowLevelRequest = request.createCreateChatCompletionRequest();
56+
var choice =
57+
((ChatCompletionToolChoiceOption.InnerString) lowLevelRequest.getToolChoice()).value();
58+
assertThat(choice).isEqualTo("none");
59+
}
60+
61+
@Test
62+
void withToolChoiceOptional() {
63+
OpenAiChatCompletionRequest request =
64+
new OpenAiChatCompletionRequest("message").withToolChoiceOptional();
65+
66+
var lowLevelRequest = request.createCreateChatCompletionRequest();
67+
var choice =
68+
((ChatCompletionToolChoiceOption.InnerString) lowLevelRequest.getToolChoice()).value();
69+
assertThat(choice).isEqualTo("auto");
70+
}
71+
72+
@Test
73+
void withToolChoiceRequired() {
74+
OpenAiChatCompletionRequest request =
75+
new OpenAiChatCompletionRequest("message").withToolChoiceRequired();
76+
77+
var lowLevelRequest = request.createCreateChatCompletionRequest();
78+
var choice =
79+
((ChatCompletionToolChoiceOption.InnerString) lowLevelRequest.getToolChoice()).value();
80+
assertThat(choice).isEqualTo("required");
81+
}
82+
83+
@Test
84+
void withToolChoiceFunction() {
85+
OpenAiChatCompletionRequest request =
86+
new OpenAiChatCompletionRequest("message").withToolChoiceFunction("functionName");
87+
88+
var lowLevelRequest = request.createCreateChatCompletionRequest();
89+
var choice =
90+
((ChatCompletionToolChoiceOption.InnerChatCompletionNamedToolChoice)
91+
lowLevelRequest.getToolChoice())
92+
.value();
93+
assertThat(choice.getType().getValue()).isEqualTo("function");
94+
assertThat(choice.getFunction().getName()).isEqualTo("functionName");
95+
}
4896
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,8 @@
2525
import static org.mockito.Mockito.times;
2626
import static org.mockito.Mockito.when;
2727

28-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
29-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
3028
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessageRole;
3129
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
32-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
3330
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ContentFilterPromptResults;
3431
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
3532
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionStreamResponseChoicesInner;
@@ -530,17 +527,11 @@ void chatCompletionTool() {
530527
final var tool =
531528
new ChatCompletionTool().type(ChatCompletionTool.TypeEnum.FUNCTION).function(function);
532529

533-
final var toolChoice =
534-
ChatCompletionToolChoiceOption.create(
535-
new ChatCompletionNamedToolChoice()
536-
.type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
537-
.function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci")));
538-
539530
final var request =
540531
new OpenAiChatCompletionRequest(
541532
"A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?")
542533
.withTools(List.of(tool))
543-
.withToolChoice(toolChoice);
534+
.withToolChoiceFunction("fibonacci");
544535

545536
var response = client.chatCompletion(request).getOriginalResponse();
546537

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
1313
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1414
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
15-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
16-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
1715
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
18-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
1916
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreate200Response;
2017
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequest;
2118
import com.sap.ai.sdk.foundationmodels.openai.generated.model.EmbeddingsCreateRequestInput;
@@ -103,18 +100,12 @@ public OpenAiChatCompletionResponse chatCompletionTools(final int months) {
103100

104101
final var tool = new ChatCompletionTool().type(FUNCTION).function(function);
105102

106-
final var toolChoice =
107-
ChatCompletionToolChoiceOption.create(
108-
new ChatCompletionNamedToolChoice()
109-
.type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
110-
.function(new ChatCompletionNamedToolChoiceFunction().name("fibonacci")));
111-
112103
final var request =
113104
new OpenAiChatCompletionRequest(
114105
"A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?"
115106
.formatted(months))
116107
.withTools(List.of(tool))
117-
.withToolChoice(toolChoice);
108+
.withToolChoiceFunction("fibonacci");
118109

119110
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
120111
}

0 commit comments

Comments
 (0)