Skip to content

Commit c52720e

Browse files
committed
--> still having testToolCallingWithoutExecution() in SpringAiOpenAiTest.java failing.
--> still fix of null of message.getText() in toAssistantMessage() method in OpenAiChatModel.java pending.
1 parent 84546c0 commit c52720e

File tree

1 file changed

+39
-45
lines changed
  • foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring

1 file changed

+39
-45
lines changed

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.sap.ai.sdk.foundationmodels.openai.spring;
22

3+
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage.tool;
34
import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled;
45

56
import com.sap.ai.sdk.foundationmodels.openai.OpenAiAssistantMessage;
@@ -11,17 +12,18 @@
1112
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessageContent;
1213
import com.sap.ai.sdk.foundationmodels.openai.OpenAiTextItem;
1314
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall;
15+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall;
1416
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessage;
15-
import io.vavr.control.Option;
17+
18+
import java.util.ArrayList;
1619
import java.util.List;
1720
import java.util.Map;
18-
import java.util.stream.Stream;
21+
import java.util.function.Function;
1922
import javax.annotation.Nonnull;
2023
import lombok.RequiredArgsConstructor;
2124
import lombok.val;
22-
import org.springframework.ai.chat.messages.AssistantMessage;
25+
import org.springframework.ai.chat.messages.*;
2326
import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
24-
import org.springframework.ai.chat.messages.ToolResponseMessage;
2527
import org.springframework.ai.chat.model.ChatModel;
2628
import org.springframework.ai.chat.model.ChatResponse;
2729
import org.springframework.ai.chat.model.Generation;
@@ -47,10 +49,13 @@ public ChatResponse call(@Nonnull final Prompt prompt) {
4749
throw new IllegalArgumentException(
4850
"Please add OpenAiChatOptions to the Prompt: new Prompt(\"message\", new OpenAiChatOptions(config))");
4951
}
52+
System.out.println("I entered OpenAiChatModel.call() with tools: " + options.getTools());
53+
val openAiRequest = toOpenAiRequest(prompt);
54+
val request = new OpenAiChatCompletionRequest(openAiRequest).withTools(options.getTools());
55+
val result = client.chatCompletion(request);
56+
val response = new ChatResponse(toGenerations(result));
5057

51-
val request =
52-
new OpenAiChatCompletionRequest(toOpenAiRequest(prompt)).withTools(options.getTools());
53-
val response = new ChatResponse(toGenerations(client.chatCompletion(request)));
58+
System.out.println("I entered OpenAiChatModel.call() with response: " + response);
5459

5560
if (isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
5661
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response);
@@ -61,44 +66,33 @@ public ChatResponse call(@Nonnull final Prompt prompt) {
6166
}
6267

6368
private List<OpenAiMessage> toOpenAiRequest(final Prompt prompt) {
64-
return prompt.getInstructions().stream()
65-
.flatMap(
66-
message ->
67-
switch (message.getMessageType()) {
68-
case USER ->
69-
Stream.of(
70-
OpenAiMessage.user(
71-
Option.of(message.getText()).getOrElse(message.getText())));
72-
case ASSISTANT -> {
73-
val assistantMessage = (AssistantMessage) message;
74-
yield Stream.of(
75-
assistantMessage.hasToolCalls()
76-
? new OpenAiAssistantMessage(
77-
new OpenAiMessageContent(
78-
List.of(
79-
new OpenAiTextItem(
80-
Option.of(message.getText())
81-
.getOrElse(message.getText())))),
82-
assistantMessage.getToolCalls().stream()
83-
.map(
84-
toolCall ->
85-
(OpenAiToolCall)
86-
new OpenAiFunctionCall(
87-
toolCall.id(),
88-
toolCall.name(),
89-
toolCall.arguments()))
90-
.toList())
91-
: new OpenAiAssistantMessage(
92-
Option.of(message.getText()).getOrElse(message.getText())));
93-
}
94-
case SYSTEM -> Stream.of(OpenAiMessage.system(message.getText()));
95-
case TOOL -> {
96-
val responses = ((ToolResponseMessage) message).getResponses();
97-
yield responses.stream()
98-
.map(resp -> OpenAiMessage.tool(resp.responseData(), resp.id()));
99-
}
100-
})
101-
.toList();
69+
final List<OpenAiMessage> result = new ArrayList<>();
70+
for (final Message message : prompt.getInstructions()) {
71+
//if(((message.getMessageType() == MessageType.USER || message.getMessageType() ==MessageType.ASSISTANT || message.getMessageType() ==MessageType.SYSTEM ) && message.getText() != null) || (message.getMessageType() == MessageType.TOOL)) {
72+
switch (message.getMessageType()) {
73+
case USER -> result.add(OpenAiMessage.user(message.getText()));
74+
case ASSISTANT -> result.add(toAssistantMessage((AssistantMessage) message));
75+
case SYSTEM -> result.add(OpenAiMessage.system(message.getText()));
76+
case TOOL -> result.addAll(toToolMessages((ToolResponseMessage) message));
77+
}
78+
//}
79+
}
80+
return result;
81+
}
82+
83+
private static OpenAiAssistantMessage toAssistantMessage(AssistantMessage message) {
84+
if (!message.hasToolCalls()) {
85+
return OpenAiMessage.assistant(message.getText());
86+
}
87+
final Function<ToolCall, OpenAiToolCall> callTranslate =
88+
toolCall -> new OpenAiFunctionCall(toolCall.id(), toolCall.name(), toolCall.arguments());
89+
val content = new OpenAiMessageContent(List.of(new OpenAiTextItem(message.getText())));
90+
val calls = message.getToolCalls().stream().map(callTranslate).toList();
91+
return new OpenAiAssistantMessage(content, calls);
92+
}
93+
94+
private static List<? extends OpenAiMessage> toToolMessages(ToolResponseMessage message) {
95+
return message.getResponses().stream().map(r -> tool(r.responseData(), r.id())).toList();
10296
}
10397

10498
@Nonnull

0 commit comments

Comments
 (0)