Skip to content

Commit 9a5aec3

Browse files
committed
First version
- Mostly tested - API design complete - Function Call as message content item - Non nullability of `Message.content()`
1 parent 9ad7225 commit 9a5aec3

File tree

11 files changed

+219
-9
lines changed

11 files changed

+219
-9
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
import static lombok.AccessLevel.PACKAGE;
44

55
import com.google.common.annotations.Beta;
6+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall;
7+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCallFunction;
68
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessage;
79
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessageContent;
10+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ToolCallType;
811
import java.util.List;
912
import javax.annotation.Nonnull;
1013
import lombok.AllArgsConstructor;
@@ -21,7 +24,7 @@
2124
@Value
2225
@Accessors(fluent = true)
2326
@AllArgsConstructor(access = PACKAGE)
24-
class OpenAiAssistantMessage implements OpenAiMessage {
27+
public class OpenAiAssistantMessage implements OpenAiMessage {
2528

2629
/** The role associated with this message. */
2730
@Nonnull String role = "assistant";
@@ -32,24 +35,52 @@ class OpenAiAssistantMessage implements OpenAiMessage {
3235
OpenAiMessageContent content;
3336

3437
/**
35-
* Creates a new assistant message with the given single message.
38+
* Creates a new assistant message with the given single message as text content.
3639
*
3740
* @param singleMessage the message.
3841
*/
3942
OpenAiAssistantMessage(@Nonnull final String singleMessage) {
4043
this(new OpenAiMessageContent(List.of(new OpenAiTextItem(singleMessage))));
4144
}
4245

46+
@Nonnull
47+
public List<OpenAiToolCallItem> getToolCalls() {
48+
return this.content().items().stream()
49+
.filter(item -> item instanceof OpenAiToolCallItem)
50+
.map(item -> (OpenAiToolCallItem) item)
51+
.toList();
52+
}
53+
4354
/**
4455
* Converts the message to a serializable object.
4556
*
4657
* @return the corresponding {@code ChatCompletionRequestAssistantMessage} object.
4758
*/
4859
@Nonnull
4960
ChatCompletionRequestAssistantMessage createChatCompletionRequestMessage() {
50-
final var textItem = (OpenAiTextItem) this.content().items().get(0);
51-
return new ChatCompletionRequestAssistantMessage()
52-
.role(ChatCompletionRequestAssistantMessage.RoleEnum.fromValue(role()))
53-
.content(ChatCompletionRequestAssistantMessageContent.create(textItem.text()));
61+
var message =
62+
new ChatCompletionRequestAssistantMessage()
63+
.role(ChatCompletionRequestAssistantMessage.RoleEnum.fromValue(role()));
64+
65+
for (var item : content().items()) {
66+
if (item instanceof OpenAiTextItem textItem) {
67+
message.content(ChatCompletionRequestAssistantMessageContent.create(textItem.text()));
68+
} else if (item instanceof OpenAiFunctionCallItem functionItem) {
69+
70+
var functionCall =
71+
new ChatCompletionMessageToolCallFunction()
72+
.name(functionItem.getName())
73+
.arguments(functionItem.getArguments());
74+
75+
var toolCall =
76+
new ChatCompletionMessageToolCall()
77+
.type(ToolCallType.FUNCTION)
78+
.id(functionItem.getId())
79+
.function(functionCall);
80+
81+
message.addToolCallsItem(toolCall);
82+
}
83+
}
84+
return message;
5485
}
5586
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat;
1010
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
1111
import java.math.BigDecimal;
12+
import java.util.Collections;
1213
import java.util.List;
1314
import java.util.Map;
1415
import java.util.Objects;
@@ -149,8 +150,18 @@ public OpenAiChatCompletionRequest(@Nonnull final String message) {
149150
@Tolerate
150151
public OpenAiChatCompletionRequest(
151152
@Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) {
153+
this(Lists.asList(message, messages));
154+
}
155+
156+
/**
157+
* Creates an OpenAiChatCompletionPrompt with a list of messages.
158+
*
159+
* @param messages the list of messages to be added to the prompt
160+
*/
161+
@Tolerate
162+
public OpenAiChatCompletionRequest(@Nonnull final List<OpenAiMessage> messages) {
152163
this(
153-
Lists.asList(message, messages),
164+
Collections.unmodifiableList(messages),
154165
null,
155166
null,
156167
null,

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CompletionUsage;
99
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponse;
1010
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
11+
import java.util.ArrayList;
1112
import java.util.Objects;
1213
import javax.annotation.Nonnull;
1314
import lombok.RequiredArgsConstructor;
@@ -61,4 +62,27 @@ public String getContent() {
6162

6263
return Objects.requireNonNullElse(getChoice().getMessage().getContent(), "");
6364
}
65+
66+
@Nonnull
67+
public OpenAiAssistantMessage getMessage() {
68+
69+
if (getChoice().getMessage().getToolCalls() == null) {
70+
return OpenAiMessage.assistant(getContent());
71+
}
72+
73+
var contentItems = new ArrayList<OpenAiContentItem>();
74+
if (getContent().isEmpty()) {
75+
contentItems.add(new OpenAiTextItem(getContent()));
76+
}
77+
78+
for (var toolCall : getChoice().getMessage().getToolCalls()) {
79+
contentItems.add(
80+
new OpenAiFunctionCallItem(
81+
toolCall.getId(),
82+
toolCall.getFunction().getName(),
83+
toolCall.getFunction().getArguments()));
84+
}
85+
86+
return new OpenAiAssistantMessage(new OpenAiMessageContent(contentItems));
87+
}
6488
}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiContentItem.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
* @since 1.4.0
99
*/
1010
@Beta
11-
public sealed interface OpenAiContentItem permits OpenAiTextItem, OpenAiImageItem {}
11+
public sealed interface OpenAiContentItem
12+
permits OpenAiTextItem, OpenAiImageItem, OpenAiToolCallItem {}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
import javax.annotation.Nonnull;
4+
5+
import com.google.common.annotations.Beta;
6+
import lombok.AllArgsConstructor;
7+
import lombok.Value;
8+
9+
@Beta
10+
@Value
11+
@AllArgsConstructor(access = lombok.AccessLevel.PACKAGE)
12+
public class OpenAiFunctionCallItem implements OpenAiToolCallItem {
13+
@Nonnull String type = "function";
14+
15+
@Nonnull String id;
16+
@Nonnull String name;
17+
@Nonnull String arguments;
18+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package com.sap.ai.sdk.foundationmodels.openai;
2+
3+
public sealed interface OpenAiToolCallItem extends OpenAiContentItem
4+
permits OpenAiFunctionCallItem {}

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiUtils.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ static ChatCompletionRequestMessage createChatCompletionRequestMessage(
3434
return assistantMessage.createChatCompletionRequestMessage();
3535
} else if (message instanceof OpenAiSystemMessage systemMessage) {
3636
return systemMessage.createChatCompletionRequestMessage();
37-
} else {
37+
} else if(message instanceof OpenAiToolMessage toolMessage) {
38+
return toolMessage.createChatCompletionRequestMessage();
39+
}
40+
else {
3841
throw new IllegalArgumentException("Unknown message type: " + message.getClass());
3942
}
4043
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAIMessageTest.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.sap.ai.sdk.foundationmodels.openai;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatNoException;
45

56
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessage;
67
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessageContent;
@@ -13,6 +14,7 @@
1314
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestToolMessageContent;
1415
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessage;
1516
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestUserMessageContent;
17+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ToolCallType;
1618
import java.net.URI;
1719
import java.util.List;
1820
import java.util.stream.Stream;
@@ -194,6 +196,22 @@ void assistantMessageToDto() {
194196
((ChatCompletionRequestAssistantMessageContent.InnerString) requestMessage.getContent())
195197
.value())
196198
.isEqualTo(validText);
199+
200+
var messageWithFunctionCall =
201+
new OpenAiAssistantMessage(
202+
new OpenAiMessageContent(
203+
List.of(new OpenAiFunctionCallItem("id", "name", "arguments"))));
204+
var requestMessageWithFunctionCall =
205+
messageWithFunctionCall.createChatCompletionRequestMessage();
206+
207+
assertThat(requestMessageWithFunctionCall.getToolCalls()).hasSize(1);
208+
assertThat(requestMessageWithFunctionCall.getToolCalls().get(0).getType())
209+
.isEqualTo(ToolCallType.FUNCTION);
210+
assertThat(requestMessageWithFunctionCall.getToolCalls().get(0).getId()).isEqualTo("id");
211+
assertThat(requestMessageWithFunctionCall.getToolCalls().get(0).getFunction().getName())
212+
.isEqualTo("name");
213+
assertThat(requestMessageWithFunctionCall.getToolCalls().get(0).getFunction().getArguments())
214+
.isEqualTo("arguments");
197215
}
198216

199217
@Test
@@ -220,4 +238,35 @@ void throwOnSystemMessageWithImage() {
220238
.hasMessageContaining(
221239
"Unknown content type for class com.sap.ai.sdk.foundationmodels.openai.OpenAiImageItem messages.");
222240
}
241+
242+
@Test
243+
void assistantMessageGetToolCalls() {
244+
var message =
245+
new OpenAiAssistantMessage(
246+
new OpenAiMessageContent(
247+
List.of(
248+
new OpenAiTextItem("text"),
249+
new OpenAiFunctionCallItem("id1", "name1", "arguments1"),
250+
new OpenAiFunctionCallItem("id2", "name2", "arguments2"))));
251+
252+
var toolCalls = message.getToolCalls();
253+
assertThat(toolCalls).hasSize(2);
254+
255+
var functionCallItem1 = (OpenAiFunctionCallItem) toolCalls.get(0);
256+
assertThat(functionCallItem1.getId()).isEqualTo("id1");
257+
assertThat(functionCallItem1.getName()).isEqualTo("name1");
258+
assertThat(functionCallItem1.getArguments()).isEqualTo("arguments1");
259+
260+
var functionCallItem2 = (OpenAiFunctionCallItem) toolCalls.get(1);
261+
assertThat(functionCallItem2.getId()).isEqualTo("id2");
262+
assertThat(functionCallItem2.getName()).isEqualTo("name2");
263+
assertThat(functionCallItem2.getArguments()).isEqualTo("arguments2");
264+
}
265+
266+
@ParameterizedTest
267+
@MethodSource("provideValidTextMessageByRole")
268+
void verifyAllMessageTypesMappedToDto(OpenAiMessage message) {
269+
assertThatNoException()
270+
.isThrownBy(() -> OpenAiUtils.createChatCompletionRequestMessage(message));
271+
}
223272
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequestTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
99
import java.math.BigDecimal;
1010
import java.util.List;
11+
12+
import lombok.ToString;
1113
import org.junit.jupiter.api.Test;
1214

1315
class OpenAiChatCompletionRequestTest {

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OpenAiV2Test.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,14 @@ void chatCompletionToolExecution() {
112112
assertThat(content).isNotEmpty();
113113
assertThat(content).contains("°C");
114114
}
115+
116+
@Test
117+
void chatCompletionToolExecutionConvenience() {
118+
final var completion = service.chatCompletionToolExecutionConvenience("Dubai", "°C");
119+
120+
String content = completion.getContent();
121+
122+
assertThat(content).isNotEmpty();
123+
assertThat(content).contains("°C");
124+
}
115125
}

0 commit comments

Comments
 (0)