Skip to content

Commit 47ca557

Browse files
committed
Merge remote-tracking branch 'origin/chatcompletion-for-springopenai2' into chatcompletion-for-springopenai2
# Conflicts: # foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java
2 parents 554e45a + 074d794 commit 47ca557

File tree

10 files changed

+141
-58
lines changed

10 files changed

+141
-58
lines changed

.github/workflows/continuous-integration.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ jobs:
1515

1616
continuous-integration:
1717
runs-on: ubuntu-latest
18+
permissions:
19+
contents: write
1820
steps:
1921

2022
- name: "Checkout repository"

.github/workflows/dependency-test.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ env:
1010
jobs:
1111
fetch-dependency-versions:
1212
runs-on: ubuntu-latest
13+
permissions:
14+
contents: none
1315
outputs:
1416
versions: ${{ steps.fetch-versions.outputs.VERSIONS }}
1517

@@ -39,6 +41,8 @@ jobs:
3941
runs-on: ubuntu-latest
4042
outputs:
4143
cache-key: ${{ steps.cache-build.outputs.cache-key }}
44+
permissions:
45+
contents: read
4246
steps:
4347
- name: "Checkout repository"
4448
uses: actions/checkout@v4
@@ -75,6 +79,8 @@ jobs:
7579
matrix:
7680
version: ${{ fromJson(needs.fetch-dependency-versions.outputs.versions) }}
7781
continue-on-error: true
82+
permissions:
83+
contents: read
7884
steps:
7985
- name: "Checkout repository"
8086
uses: actions/checkout@v4

.github/workflows/deploy-snapshot.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ on:
77

88
jobs:
99
deploy-snapshot:
10-
name: Deploy Snapshot
10+
name: "Deploy Snapshot"
1111
runs-on: ubuntu-latest
1212
timeout-minutes: 15
13+
permissions:
14+
contents: read
1315
steps:
1416
- name: "Checkout Repository"
1517
uses: actions/checkout@v4

.github/workflows/e2e-test.yaml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ env:
1010

1111
jobs:
1212
end-to-end-tests:
13+
permissions:
14+
contents: read
1315
strategy:
1416
fail-fast: false
1517
matrix:
@@ -41,6 +43,12 @@ jobs:
4143
- name: "Run tests"
4244
id: run_tests
4345
run: |
46+
if [ "${{ matrix.environment }}" = "canary" ]; then
47+
export AICORE_SERVICE_KEY="${{ secrets.AI_CORE_CANARY }}"
48+
else
49+
export AICORE_SERVICE_KEY="${{ secrets.AI_CORE_PRODUCTION }}"
50+
fi
51+
4452
MVN_ARGS="${{ env.MVN_MULTI_THREADED_ARGS }} surefire:test -pl :spring-app -DskipTests=false"
4553
mvn $MVN_ARGS "-Daicore.landscape=${{ matrix.environment }}" | tee mvn_output.log # tee writes to both the console and a file
4654
@@ -60,10 +68,17 @@ jobs:
6068
fi
6169
env:
6270
# See "End-to-end test application instructions" on the README.md to update the secret
63-
AICORE_SERVICE_KEY: ${{ secrets[matrix.secret-name] }}
71+
AI_CORE_PRODUCTION: ${{ secrets.production }}
72+
AI_CORE_CANARY: ${{ secrets.canary }}
6473

6574
- name: "Start Application Locally"
6675
run: |
76+
if [ "${{ matrix.environment }}" = "canary" ]; then
77+
export AICORE_SERVICE_KEY="${{ secrets.AI_CORE_CANARY }}"
78+
else
79+
export AICORE_SERVICE_KEY="${{ secrets.AI_CORE_PRODUCTION }}"
80+
fi
81+
6782
cd sample-code/spring-app
6883
mvn spring-boot:run &
6984
timeout=15
@@ -77,7 +92,8 @@ jobs:
7792
done
7893
env:
7994
# See "End-to-end test application instructions" on the README.md to update the secret
80-
AICORE_SERVICE_KEY: ${{ secrets[matrix.secret-name] }}
95+
AI_CORE_PRODUCTION: ${{ secrets.production }}
96+
AI_CORE_CANARY: ${{ secrets.canary }}
8197

8298
- name: "Health Check"
8399
# print response body with headers to stdout. q:body only O:print -:stdout S:headers

.github/workflows/fosstars-report.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ env:
1313

1414
jobs:
1515
create_fosstars_report:
16-
runs-on: ubuntu-latest
1716
name: "Security rating"
17+
runs-on: ubuntu-latest
18+
permissions:
19+
contents: read
1820
steps:
1921
- name: "Checkout repository"
2022
uses: actions/checkout@v4

.github/workflows/prepare-release.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ jobs:
2828
release-commit: ${{ steps.prepare-release.outputs.RELEASE_COMMIT_ID }}
2929
release-tag: ${{ steps.prepare-release.outputs.TAG_NAME }}
3030
runs-on: ubuntu-latest
31+
permissions:
32+
contents: write
3133
steps:
3234
- name: "Checkout Repository"
3335
uses: actions/checkout@v4
@@ -155,6 +157,9 @@ jobs:
155157
outputs:
156158
pr-url: ${{ steps.create-release-notes-pr.outputs.PR_URL }}
157159
runs-on: ubuntu-latest
160+
permissions:
161+
contents: write
162+
pull-requests: write
158163
steps:
159164
- name: "Checkout Code Repository"
160165
uses: actions/checkout@v4
@@ -233,6 +238,9 @@ jobs:
233238
outputs:
234239
pr-url: ${{ steps.create-code-pr.outputs.PR_URL }}
235240
runs-on: ubuntu-latest
241+
permissions:
242+
contents: write
243+
pull-requests: write
236244
steps:
237245
- name: "Checkout Repository"
238246
uses: actions/checkout@v4

.github/workflows/weekly-spec-update.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ on:
77
jobs:
88
update-all-specs:
99
runs-on: ubuntu-latest
10+
permissions:
11+
contents: write
12+
pull-requests: write
1013

1114
strategy:
1215
matrix:

foundation-models/openai/pom.xml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
</scm>
3939
<properties>
4040
<project.rootdir>${project.basedir}/../../</project.rootdir>
41-
<coverage.complexity>83%</coverage.complexity>
42-
<coverage.line>92%</coverage.line>
43-
<coverage.instruction>90%</coverage.instruction>
44-
<coverage.branch>81%</coverage.branch>
41+
<coverage.complexity>81%</coverage.complexity>
42+
<coverage.line>91%</coverage.line>
43+
<coverage.instruction>89%</coverage.instruction>
44+
<coverage.branch>79%</coverage.branch>
4545
<coverage.method>90%</coverage.method>
4646
<coverage.class>92%</coverage.class>
4747
</properties>

foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
1313
import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall;
1414
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall;
15-
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessage;
1615
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
16+
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner;
1717
import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
1818
import io.vavr.control.Option;
19+
import java.math.BigDecimal;
1920
import java.util.ArrayList;
2021
import java.util.List;
2122
import java.util.Map;
@@ -27,12 +28,14 @@
2728
import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
2829
import org.springframework.ai.chat.messages.Message;
2930
import org.springframework.ai.chat.messages.ToolResponseMessage;
31+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3032
import org.springframework.ai.chat.model.ChatModel;
3133
import org.springframework.ai.chat.model.ChatResponse;
3234
import org.springframework.ai.chat.model.Generation;
35+
import org.springframework.ai.chat.prompt.ChatOptions;
3336
import org.springframework.ai.chat.prompt.Prompt;
34-
import org.springframework.ai.model.tool.DefaultToolCallingChatOptions;
3537
import org.springframework.ai.model.tool.DefaultToolCallingManager;
38+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3639
import reactor.core.publisher.Flux;
3740

3841
/**
@@ -50,34 +53,40 @@ public class OpenAiChatModel implements ChatModel {
5053
@Override
5154
@Nonnull
5255
public ChatResponse call(@Nonnull final Prompt prompt) {
53-
val openAiRequest = toOpenAiRequest(prompt);
54-
var request = new OpenAiChatCompletionRequest(openAiRequest);
56+
val options = prompt.getOptions();
57+
var request = new OpenAiChatCompletionRequest(extractMessages(prompt));
5558

56-
if ((prompt.getOptions() instanceof DefaultToolCallingChatOptions options)) {
57-
request = request.withTools(extractTools(options));
59+
if (options != null) {
60+
request = extractOptions(request, options);
61+
}
62+
if ((options instanceof ToolCallingChatOptions toolOptions)) {
63+
request = request.withTools(extractTools(toolOptions));
5864
}
5965

6066
val result = client.chatCompletion(request);
6167
val response = new ChatResponse(toGenerations(result));
6268

63-
if (prompt.getOptions() != null
64-
&& isInternalToolExecutionEnabled(prompt.getOptions())
65-
&& response.hasToolCalls()) {
69+
if (options != null && isInternalToolExecutionEnabled(options) && response.hasToolCalls()) {
6670
val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response);
6771
// Send the tool execution result back to the model.
68-
return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
72+
return call(new Prompt(toolExecutionResult.conversationHistory(), options));
6973
}
7074
return response;
7175
}
7276

7377
@Override
7478
@Nonnull
7579
public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
76-
val openAiRequest = toOpenAiRequest(prompt);
77-
var request = new OpenAiChatCompletionRequest(openAiRequest);
78-
if ((prompt.getOptions() instanceof DefaultToolCallingChatOptions options)) {
79-
request = request.withTools(extractTools(options));
80+
val options = prompt.getOptions();
81+
var request = new OpenAiChatCompletionRequest(extractMessages(prompt));
82+
83+
if (options != null) {
84+
request = extractOptions(request, options);
85+
}
86+
if ((options instanceof ToolCallingChatOptions toolOptions)) {
87+
request = request.withTools(extractTools(toolOptions));
8088
}
89+
8190
val stream = client.streamChatCompletionDeltas(request);
8291
final Flux<OpenAiChatCompletionDelta> flux =
8392
Flux.generate(
@@ -90,36 +99,16 @@ public Flux<ChatResponse> stream(@Nonnull final Prompt prompt) {
9099
}
91100
return iterator;
92101
});
93-
return flux.map(OpenAiChatModel::toChatResponse);
94-
}
95-
96-
private List<ChatCompletionTool> extractTools(final DefaultToolCallingChatOptions options) {
97-
val tools = new ArrayList<ChatCompletionTool>();
98-
for (val toolCallback : options.getToolCallbacks()) {
99-
val toolDefinition = toolCallback.getToolDefinition();
100-
try {
101-
final Map<String, Object> params =
102-
new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {});
103-
val toolType = ChatCompletionTool.TypeEnum.FUNCTION;
104-
val toolFunction =
105-
new FunctionObject()
106-
.name(toolDefinition.name())
107-
.description(toolDefinition.description())
108-
.parameters(params);
109-
val tool = new ChatCompletionTool().type(toolType).function(toolFunction);
110-
tools.add(tool);
111-
} catch (JsonProcessingException ignored) {
112-
}
113-
}
114-
return tools;
102+
return flux.map(
103+
delta -> {
104+
val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of());
105+
val metadata =
106+
ChatGenerationMetadata.builder().finishReason(delta.getFinishReason()).build();
107+
return new ChatResponse(List.of(new Generation(assistantMessage, metadata)));
108+
});
115109
}
116110

117-
private static ChatResponse toChatResponse(final OpenAiChatCompletionDelta delta) {
118-
val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of());
119-
return new ChatResponse(List.of(new Generation(assistantMessage)));
120-
}
121-
122-
private List<OpenAiMessage> toOpenAiRequest(final Prompt prompt) {
111+
private List<OpenAiMessage> extractMessages(final Prompt prompt) {
123112
final List<OpenAiMessage> result = new ArrayList<>();
124113
for (final Message message : prompt.getInstructions()) {
125114
switch (message.getMessageType()) {
@@ -152,24 +141,73 @@ private static void addToolMessages(
152141
}
153142

154143
@Nonnull
155-
static List<Generation> toGenerations(@Nonnull final OpenAiChatCompletionResponse result) {
144+
private static List<Generation> toGenerations(
145+
@Nonnull final OpenAiChatCompletionResponse result) {
156146
return result.getOriginalResponse().getChoices().stream()
157-
.map(message -> toGeneration(message.getMessage()))
147+
.map(OpenAiChatModel::toGeneration)
158148
.toList();
159149
}
160150

161151
@Nonnull
162-
static Generation toGeneration(@Nonnull final ChatCompletionResponseMessage choice) {
163-
// no metadata for now
152+
private static Generation toGeneration(
153+
@Nonnull final CreateChatCompletionResponseChoicesInner choice) {
154+
val metadata =
155+
ChatGenerationMetadata.builder().finishReason(choice.getFinishReason().getValue());
156+
metadata.metadata("index", choice.getIndex());
157+
if (choice.getLogprobs() != null && !choice.getLogprobs().getContent().isEmpty()) {
158+
metadata.metadata("logprobs", choice.getLogprobs().getContent());
159+
}
160+
val message = choice.getMessage();
164161
val calls = new ArrayList<ToolCall>();
165-
if (choice.getToolCalls() != null) {
166-
for (final ChatCompletionMessageToolCall c : choice.getToolCalls()) {
162+
if (message.getToolCalls() != null) {
163+
for (final ChatCompletionMessageToolCall c : message.getToolCalls()) {
167164
val fnc = c.getFunction();
168165
calls.add(
169166
new ToolCall(c.getId(), c.getType().getValue(), fnc.getName(), fnc.getArguments()));
170167
}
171168
}
172-
val message = new AssistantMessage(choice.getContent(), Map.of(), calls);
173-
return new Generation(message);
169+
170+
val assistantMessage = new AssistantMessage(message.getContent(), Map.of(), calls);
171+
return new Generation(assistantMessage, metadata.build());
172+
}
173+
174+
private OpenAiChatCompletionRequest extractOptions(
175+
@Nonnull OpenAiChatCompletionRequest request, @Nonnull final ChatOptions options) {
176+
request = request.withStop(options.getStopSequences()).withMaxTokens(options.getMaxTokens());
177+
if (options.getTemperature() != null) {
178+
request = request.withTemperature(BigDecimal.valueOf(options.getTemperature()));
179+
}
180+
if (options.getTopP() != null) {
181+
request = request.withTopP(BigDecimal.valueOf(options.getTopP()));
182+
}
183+
if (options.getPresencePenalty() != null) {
184+
request = request.withPresencePenalty(BigDecimal.valueOf(options.getPresencePenalty()));
185+
}
186+
if (options.getFrequencyPenalty() != null) {
187+
request = request.withFrequencyPenalty(BigDecimal.valueOf(options.getFrequencyPenalty()));
188+
}
189+
return request;
190+
}
191+
192+
private List<ChatCompletionTool> extractTools(final ToolCallingChatOptions options) {
193+
val tools = new ArrayList<ChatCompletionTool>();
194+
for (val toolCallback : options.getToolCallbacks()) {
195+
val toolDefinition = toolCallback.getToolDefinition();
196+
try {
197+
final Map<String, Object> params =
198+
new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {});
199+
val tool =
200+
new ChatCompletionTool()
201+
.type(ChatCompletionTool.TypeEnum.FUNCTION)
202+
.function(
203+
new FunctionObject()
204+
.name(toolDefinition.name())
205+
.description(toolDefinition.description())
206+
.parameters(params));
207+
tools.add(tool);
208+
} catch (JsonProcessingException ignored) {
209+
}
210+
}
211+
return tools;
174212
}
175213
}

foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModelTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ void testStreamCompletion() throws IOException {
114114
assertThat(deltaList.get(3).getResult().getOutput().getText()).isEqualTo("!");
115115
assertThat(deltaList.get(4).getResult().getOutput().getText()).isEqualTo("");
116116

117+
assertThat(deltaList.get(0).getResult().getMetadata().getFinishReason()).isEqualTo(null);
118+
assertThat(deltaList.get(1).getResult().getMetadata().getFinishReason()).isEqualTo(null);
119+
assertThat(deltaList.get(2).getResult().getMetadata().getFinishReason()).isEqualTo(null);
120+
assertThat(deltaList.get(3).getResult().getMetadata().getFinishReason()).isEqualTo(null);
121+
assertThat(deltaList.get(4).getResult().getMetadata().getFinishReason()).isEqualTo("stop");
122+
117123
Mockito.verify(inputStream, times(1)).close();
118124
}
119125
}

0 commit comments

Comments
 (0)