Skip to content

Commit 8a7b914

Browse files
newtorka-d
andauthored
chore: Consolidate OpenAI convenience methods to set tool-choice (#351)
* Initial * Apply suggestion --------- Co-authored-by: Alexander Dümont <[email protected]>
1 parent 9de1c1b commit 8a7b914

File tree

5 files changed

+65
-44
lines changed

5 files changed

+65
-44
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import com.google.common.annotations.Beta;
44
import com.google.common.collect.Lists;
5-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
6-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
75
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
86
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
97
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
@@ -253,48 +251,24 @@ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs)
253251
}
254252

255253
/**
256-
* Only message generation will be performed without calling any tool.
254+
* Define the model behavior towards calling functions.
257255
*
258-
* @return the current OpenAiChatCompletionRequest instance.
259-
*/
260-
@Nonnull
261-
public OpenAiChatCompletionRequest withToolChoiceNone() {
262-
return this.withToolChoice(ChatCompletionToolChoiceOption.create("none"));
263-
}
264-
265-
/**
266-
* The model may decide whether to call a (one or more) tool.
267-
*
268-
* @return the current OpenAiChatCompletionRequest instance.
269-
*/
270-
@Nonnull
271-
public OpenAiChatCompletionRequest withToolChoiceOptional() {
272-
return this.withToolChoice(ChatCompletionToolChoiceOption.create("auto"));
273-
}
274-
275-
/**
276-
* The model must call one or more tools as part of its processing.
256+
* <p>Example:
277257
*
278-
* @return the current OpenAiChatCompletionRequest instance.
279-
*/
280-
@Nonnull
281-
public OpenAiChatCompletionRequest withToolChoiceRequired() {
282-
return this.withToolChoice(ChatCompletionToolChoiceOption.create("required"));
283-
}
284-
285-
/**
286-
* The model must call the function specified by {@code functionName}.
258+
* <ul>
259+
* <li><code>.withToolChoice(OpenAiToolChoice.NONE)</code>
260+
* <li><code>.withToolChoice(OpenAiToolChoice.OPTIONAL)</code>
261+
* <li><code>.withToolChoice(OpenAiToolChoice.REQUIRED)</code>
262+
* <li><code>.withToolChoice(OpenAiToolChoice.function("fibonacci")</code>
263+
* </ul>
287264
*
288-
* @param functionName the name of the function that must be called.
265+
* @param choice the generic tool choice.
289266
* @return the current OpenAiChatCompletionRequest instance.
290267
*/
291268
@Nonnull
292-
public OpenAiChatCompletionRequest withToolChoiceFunction(@Nonnull final String functionName) {
293-
return this.withToolChoice(
294-
ChatCompletionToolChoiceOption.create(
295-
new ChatCompletionNamedToolChoice()
296-
.type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
297-
.function(new ChatCompletionNamedToolChoiceFunction().name(functionName))));
269+
@Tolerate
270+
public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoice choice) {
271+
return this.withToolChoice(choice.toolChoice);
298272
}
299273

300274
/**
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoice;
4+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionNamedToolChoiceFunction;
5+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
6+
import javax.annotation.Nonnull;
7+
import lombok.AccessLevel;
8+
import lombok.RequiredArgsConstructor;
9+
10+
/**
11+
* OpenAi ToolChoice to specify whether to call which tool.
12+
*
13+
* @since 1.4.0
14+
*/
15+
@RequiredArgsConstructor(access = AccessLevel.PROTECTED)
16+
public class OpenAiToolChoice {
17+
@Nonnull final ChatCompletionToolChoiceOption toolChoice;
18+
19+
/** Only message generation will be performed without calling any tool. */
20+
public static final OpenAiToolChoice NONE =
21+
new OpenAiToolChoice(ChatCompletionToolChoiceOption.create("none"));
22+
23+
/** The model may decide whether to call a (one or more) tool. */
24+
public static final OpenAiToolChoice OPTIONAL =
25+
new OpenAiToolChoice(ChatCompletionToolChoiceOption.create("auto"));
26+
27+
/** The model must call one or more tools as part of its processing. */
28+
public static final OpenAiToolChoice REQUIRED =
29+
new OpenAiToolChoice(ChatCompletionToolChoiceOption.create("required"));
30+
31+
/**
32+
* The model must call the function specified by {@code functionName}.
33+
*
34+
* @param functionName the name of the function that must be called.
35+
* @return the OpenAI tool choice.
36+
*/
37+
@Nonnull
38+
public static OpenAiToolChoice function(@Nonnull final String functionName) {
39+
return new OpenAiToolChoice(
40+
ChatCompletionToolChoiceOption.create(
41+
new ChatCompletionNamedToolChoice()
42+
.type(ChatCompletionNamedToolChoice.TypeEnum.FUNCTION)
43+
.function(new ChatCompletionNamedToolChoiceFunction().name(functionName))));
44+
}
45+
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void createWithExistingRequest() {
5050
@Test
5151
void withToolChoiceNone() {
5252
OpenAiChatCompletionRequest request =
53-
new OpenAiChatCompletionRequest("message").withToolChoiceNone();
53+
new OpenAiChatCompletionRequest("message").withToolChoice(OpenAiToolChoice.NONE);
5454

5555
var lowLevelRequest = request.createCreateChatCompletionRequest();
5656
var choice =
@@ -61,7 +61,7 @@ void withToolChoiceNone() {
6161
@Test
6262
void withToolChoiceOptional() {
6363
OpenAiChatCompletionRequest request =
64-
new OpenAiChatCompletionRequest("message").withToolChoiceOptional();
64+
new OpenAiChatCompletionRequest("message").withToolChoice(OpenAiToolChoice.OPTIONAL);
6565

6666
var lowLevelRequest = request.createCreateChatCompletionRequest();
6767
var choice =
@@ -72,7 +72,7 @@ void withToolChoiceOptional() {
7272
@Test
7373
void withToolChoiceRequired() {
7474
OpenAiChatCompletionRequest request =
75-
new OpenAiChatCompletionRequest("message").withToolChoiceRequired();
75+
new OpenAiChatCompletionRequest("message").withToolChoice(OpenAiToolChoice.REQUIRED);
7676

7777
var lowLevelRequest = request.createCreateChatCompletionRequest();
7878
var choice =
@@ -83,7 +83,8 @@ void withToolChoiceRequired() {
8383
@Test
8484
void withToolChoiceFunction() {
8585
OpenAiChatCompletionRequest request =
86-
new OpenAiChatCompletionRequest("message").withToolChoiceFunction("functionName");
86+
new OpenAiChatCompletionRequest("message")
87+
.withToolChoice(OpenAiToolChoice.function("functionName"));
8788

8889
var lowLevelRequest = request.createCreateChatCompletionRequest();
8990
var choice =

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientGeneratedTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ void chatCompletionTool() {
527527
new OpenAiChatCompletionRequest(
528528
"A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after 12 months?")
529529
.withTools(List.of(tool))
530-
.withToolChoiceFunction("fibonacci");
530+
.withToolChoice(OpenAiToolChoice.function("fibonacci"));
531531

532532
var response = client.chatCompletion(request).getOriginalResponse();
533533

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/services/OpenAiServiceV2.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import com.sap.ai.sdk.foundationmodels.openai.OpenAiEmbeddingResponse;
1515
import com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem;
1616
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
17+
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolChoice;
1718
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
1819
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
1920
import java.util.List;
@@ -104,7 +105,7 @@ public OpenAiChatCompletionResponse chatCompletionTools(final int months) {
104105
"A pair of rabbits is placed in a field. Each month, every pair produces one new pair, starting from the second month. How many rabbits will there be after %s months?"
105106
.formatted(months))
106107
.withTools(List.of(tool))
107-
.withToolChoiceFunction("fibonacci");
108+
.withToolChoice(OpenAiToolChoice.function("fibonacci"));
108109

109110
return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
110111
}

0 commit comments

Comments
 (0)