Skip to content

Commit ed7c6e4

Browse files
finito
1 parent ebc21ff commit ed7c6e4

File tree

14 files changed

+436
-181
lines changed

14 files changed

+436
-181
lines changed

.github/workflows/e2e-test.yaml

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,37 @@ jobs:
3838
MVN_ARGS="${{ env.MVN_MULTI_THREADED_ARGS }} clean install -DskipTests -DskipFormatting"
3939
mvn $MVN_ARGS
4040
41-
- name: "Run tests"
41+
- name: "Run Spring Boot tests"
4242
run: |
4343
MVN_ARGS="${{ env.MVN_MULTI_THREADED_ARGS }} surefire:test -pl :spring-app -DskipTests=false"
4444
mvn $MVN_ARGS "-Daicore.landscape=${{ matrix.environment }}"
4545
env:
4646
# See "End-to-end test application instructions" on the README.md to update the secret
4747
AICORE_SERVICE_KEY: ${{ secrets[matrix.secret-name] }}
4848

49-
- name: "Start Application Locally"
49+
- name: "Run Spring AI tests"
5050
run: |
51-
cd sample-code/spring-app
52-
mvn spring-boot:run &
53-
timeout=15
54-
while ! nc -z localhost 8080; do
55-
sleep 1
56-
timeout=$((timeout - 1))
57-
if [ $timeout -le 0 ]; then
58-
echo "Server did not start within 15 seconds."
59-
exit 1
60-
fi
51+
MVN_ARGS="${{ env.MVN_MULTI_THREADED_ARGS }} surefire:test -pl :spring-ai-app -DskipTests=false"
52+
mvn $MVN_ARGS "-Daicore.landscape=${{ matrix.environment }}"
53+
env:
54+
# See "End-to-end test application instructions" on the README.md to update the secret
55+
AICORE_SERVICE_KEY: ${{ secrets[matrix.secret-name] }}
56+
57+
- name: "Start Applications Locally"
58+
run: |
59+
for project in spring-app spring-ai-app; do
60+
cd sample-code/$project
61+
mvn spring-boot:run &
62+
timeout=15
63+
while ! nc -z localhost 8080; do
64+
sleep 1
65+
timeout=$((timeout - 1))
66+
if [ $timeout -le 0 ]; then
67+
echo "Server did not start within 15 seconds."
68+
exit 1
69+
fi
70+
done
71+
cd ..
6172
done
6273
env:
6374
# See "End-to-end test application instructions" on the README.md to update the secret

orchestration/pom.xml

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131
</developers>
3232
<properties>
3333
<project.rootdir>${project.basedir}/../</project.rootdir>
34-
<coverage.complexity>70%</coverage.complexity>
35-
<coverage.line>87%</coverage.line>
36-
<coverage.instruction>88%</coverage.instruction>
37-
<coverage.branch>65%</coverage.branch>
38-
<coverage.method>65%</coverage.method>
34+
<coverage.complexity>78%</coverage.complexity>
35+
<coverage.line>91%</coverage.line>
36+
<coverage.instruction>91%</coverage.instruction>
37+
<coverage.branch>70%</coverage.branch>
38+
<coverage.method>70%</coverage.method>
3939
<coverage.class>100%</coverage.class>
4040
</properties>
4141

@@ -54,7 +54,10 @@
5454
<groupId>com.sap.cloud.sdk.cloudplatform</groupId>
5555
<artifactId>connectivity-apache-httpclient5</artifactId>
5656
</dependency>
57-
57+
<dependency>
58+
<groupId>org.springframework.ai</groupId>
59+
<artifactId>spring-ai-core</artifactId>
60+
</dependency>
5861
<dependency>
5962
<groupId>org.apache.httpcomponents.core5</groupId>
6063
<artifactId>httpcore5</artifactId>
@@ -123,24 +126,6 @@
123126
<artifactId>junit-jupiter-params</artifactId>
124127
<scope>test</scope>
125128
</dependency>
126-
<dependency>
127-
<groupId>org.springframework.boot</groupId>
128-
<artifactId>spring-boot-autoconfigure</artifactId>
129-
<version>3.4.1</version>
130-
<scope>compile</scope>
131-
</dependency>
132-
<dependency>
133-
<groupId>org.springframework.ai</groupId>
134-
<artifactId>spring-ai-core</artifactId>
135-
<version>1.0.0-SNAPSHOT</version>
136-
<scope>compile</scope>
137-
</dependency>
138-
<dependency>
139-
<groupId>org.springframework</groupId>
140-
<artifactId>spring-context</artifactId>
141-
<version>6.2.1</version>
142-
<scope>compile</scope>
143-
</dependency>
144129
</dependencies>
145130

146131
<profiles>

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22

33
import com.sap.ai.sdk.orchestration.AssistantMessage;
44
import com.sap.ai.sdk.orchestration.OrchestrationClient;
5-
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
65
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
7-
6+
import com.sap.ai.sdk.orchestration.SystemMessage;
7+
import com.sap.ai.sdk.orchestration.UserMessage;
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.function.Function;
1111
import javax.annotation.Nonnull;
12-
13-
import com.sap.ai.sdk.orchestration.SystemMessage;
14-
import com.sap.ai.sdk.orchestration.UserMessage;
1512
import lombok.RequiredArgsConstructor;
1613
import lombok.extern.slf4j.Slf4j;
1714
import lombok.val;
@@ -24,13 +21,21 @@
2421
@Slf4j
2522
@RequiredArgsConstructor
2623
public class OrchestrationChatModel implements ChatModel {
27-
@Nonnull private final OrchestrationClient client = new OrchestrationClient();
24+
@Nonnull private OrchestrationClient client;
2825

26+
@Nonnull
2927
@Override
30-
public ChatResponse call(Prompt prompt) {
31-
val orchestrationPrompt = toOrchestrationPrompt(prompt);
32-
val response = client.chatCompletion(orchestrationPrompt, ((OrchestrationChatOptions) prompt.getOptions()).getConfig());
33-
return OrchestrationChatResponse.fromOrchestrationResponse(response.getOriginalResponse());
28+
public ChatResponse call(@Nonnull final Prompt prompt) {
29+
30+
if (prompt.getOptions() != null
31+
&& prompt.getOptions() instanceof OrchestrationChatOptions options) {
32+
33+
val orchestrationPrompt = toOrchestrationPrompt(prompt);
34+
val response = client.chatCompletion(orchestrationPrompt, options.getConfig());
35+
return OrchestrationChatResponse.fromOrchestrationResponse(response.getOriginalResponse());
36+
}
37+
throw new IllegalArgumentException(
38+
"Please add OrchestrationChatOptions to the Prompt: new Prompt(\"message\", new OrchestrationChatOptions(config))");
3439
}
3540

3641
@Nonnull

orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java

Lines changed: 144 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
package com.sap.ai.sdk.orchestration.spring;
22

3+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.FREQUENCY_PENALTY;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.MAX_TOKENS;
5+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.PRESENCE_PENALTY;
6+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE;
7+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TOP_P;
8+
39
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
410
import com.sap.ai.sdk.orchestration.model.LLMModuleConfig;
5-
11+
import java.util.ArrayList;
612
import java.util.LinkedHashMap;
713
import java.util.List;
8-
import java.util.Map;
914
import java.util.Objects;
1015
import javax.annotation.Nonnull;
1116
import javax.annotation.Nullable;
1217
import lombok.AccessLevel;
1318
import lombok.Data;
1419
import lombok.Getter;
1520
import lombok.Setter;
21+
import lombok.val;
1622
import org.springframework.ai.chat.prompt.ChatOptions;
1723

1824
/** Configuration to be used for orchestration requests. */
@@ -21,81 +27,194 @@
2127
@Setter(AccessLevel.NONE)
2228
public class OrchestrationChatOptions implements ChatOptions {
2329

24-
@Getter(AccessLevel.PUBLIC)
25-
@Nonnull
26-
private Map<String, String> templateParameters = Map.of();
27-
2830
@Getter(AccessLevel.PUBLIC)
2931
@Setter(AccessLevel.PUBLIC)
3032
@Nonnull
3133
OrchestrationModuleConfig config;
3234

33-
// region satisfy the ChatOptions interface, delegating to the LLM config
34-
@Nullable
35+
/**
36+
* Returns the model to use for the chat.
37+
*
38+
* @return the model to use for the chat
39+
* @see com.sap.ai.sdk.orchestration.OrchestrationAiModel
40+
*/
41+
@Nonnull
3542
@Override
3643
public String getModel() {
3744
return getLlmConfigNonNull().getModelName();
3845
}
3946

40-
@Nullable
41-
String getModelVersion() {
47+
private void setModel(@Nonnull final String model) {
48+
getLlmConfigNonNull().setModelName(model);
49+
}
50+
51+
/**
52+
* Returns the model version to use for the chat. "latest" by default.
53+
*
54+
* @return the model version to use for the chat.
55+
*/
56+
@Nonnull
57+
public String getModelVersion() {
4258
return getLlmConfigNonNull().getModelVersion();
4359
}
4460

61+
private void setModelVersion(@Nonnull final String modelVersion) {
62+
getLlmConfigNonNull().setModelVersion(modelVersion);
63+
}
64+
65+
/**
66+
* Returns the frequency penalty to use for the chat.
67+
*
68+
* @return the frequency penalty to use for the chat
69+
*/
4570
@Nullable
4671
@Override
4772
public Double getFrequencyPenalty() {
48-
return getLlmConfigParam("frequencyPenalty", Double.class);
73+
return getLlmConfigParam(FREQUENCY_PENALTY.getName());
4974
}
5075

76+
private void setFrequencyPenalty(@Nonnull final Double frequencyPenalty) {
77+
setLlmConfigParam(FREQUENCY_PENALTY.getName(), frequencyPenalty);
78+
}
79+
80+
/**
81+
* Returns the maximum number of tokens to use for the chat.
82+
*
83+
* @return the maximum number of tokens to use for the chat
84+
*/
5185
@Nullable
5286
@Override
5387
public Integer getMaxTokens() {
54-
return getLlmConfigParam("maxTokens", Integer.class);
88+
return getLlmConfigParam(MAX_TOKENS.getName());
89+
}
90+
91+
private void setMaxTokens(@Nonnull final Integer maxTokens) {
92+
setLlmConfigParam(MAX_TOKENS.getName(), maxTokens);
5593
}
5694

95+
/**
96+
* Returns the presence penalty to use for the chat.
97+
*
98+
* @return the presence penalty to use for the chat
99+
*/
57100
@Nullable
58101
@Override
59102
public Double getPresencePenalty() {
60-
return getLlmConfigParam("presencePenalty", Double.class);
103+
return getLlmConfigParam(PRESENCE_PENALTY.getName());
61104
}
62105

63-
@SuppressWarnings("unchecked")
106+
private void setPresencePenalty(@Nonnull final Double presencePenalty) {
107+
setLlmConfigParam(PRESENCE_PENALTY.getName(), presencePenalty);
108+
}
109+
110+
/**
111+
* Returns the stop sequences to use for the chat.
112+
*
113+
* @return the stop sequences to use for the chat
114+
*/
64115
@Nullable
65116
@Override
66117
public List<String> getStopSequences() {
67-
return getLlmConfigParam("stopSequences", List.class);
118+
return getLlmConfigParam("stop_sequences");
119+
}
120+
121+
private void setStopSequences(@Nonnull final List<String> stopSequences) {
122+
setLlmConfigParam("stop_sequences", stopSequences);
68123
}
69124

125+
/**
126+
* Returns the temperature to use for the chat.
127+
*
128+
* @return the temperature to use for the chat
129+
*/
70130
@Nullable
71131
@Override
72132
public Double getTemperature() {
73-
return getLlmConfigParam("temperature", Double.class);
133+
return getLlmConfigParam(TEMPERATURE.getName());
74134
}
75135

136+
private void setTemperature(@Nonnull final Double temperature) {
137+
setLlmConfigParam(TEMPERATURE.getName(), temperature);
138+
}
139+
140+
/**
141+
* Returns the top K to use for the chat.
142+
*
143+
* @return the top K to use for the chat
144+
*/
76145
@Nullable
77146
@Override
78147
public Integer getTopK() {
79-
return getLlmConfigParam("topK", Integer.class);
148+
return getLlmConfigParam("top_k");
149+
}
150+
151+
private void setTopK(@Nonnull final Integer topK) {
152+
setLlmConfigParam("top_k", topK);
80153
}
81154

155+
/**
156+
* Returns the top P to use for the chat.
157+
*
158+
* @return the top P to use for the chat
159+
*/
82160
@Nullable
83161
@Override
84162
public Double getTopP() {
85-
return getLlmConfigParam("topP", Double.class);
163+
return getLlmConfigParam(TOP_P.getName());
86164
}
87165

166+
private void setTopP(@Nonnull final Double topP) {
167+
setLlmConfigParam(TOP_P.getName(), topP);
168+
}
169+
170+
/**
171+
* Returns a copy of this {@link OrchestrationChatOptions}.
172+
*
173+
* @return a copy of this {@link OrchestrationChatOptions}
174+
*/
175+
@SuppressWarnings("unchecked") // The same suppress is in DefaultChatOptions
176+
@Nonnull
88177
@Override
89-
public OrchestrationChatOptions copy() {
90-
var copy = new OrchestrationChatOptions(config);
91-
copy.templateParameters.putAll(this.templateParameters);
92-
return copy;
178+
public <T extends ChatOptions> T copy() {
179+
val copy = new OrchestrationChatOptions(config);
180+
copy.setModel(this.getModel());
181+
copy.setModelVersion(this.getModelVersion());
182+
if (getFrequencyPenalty() != null) {
183+
copy.setFrequencyPenalty(this.getFrequencyPenalty());
184+
}
185+
if (getMaxTokens() != null) {
186+
copy.setMaxTokens(this.getMaxTokens());
187+
}
188+
if (getPresencePenalty() != null) {
189+
copy.setPresencePenalty(this.getPresencePenalty());
190+
}
191+
if (getStopSequences() != null) {
192+
copy.setStopSequences(new ArrayList<>(this.getStopSequences()));
193+
}
194+
if (getTemperature() != null) {
195+
copy.setTemperature(this.getTemperature());
196+
}
197+
if (getTopK() != null) {
198+
copy.setTopK(this.getTopK());
199+
}
200+
if (getTopP() != null) {
201+
copy.setTopP(this.getTopP());
202+
}
203+
return (T) copy;
93204
}
94205

95206
@SuppressWarnings("unchecked")
96-
@Nonnull
97-
private <T> T getLlmConfigParam(@Nonnull final String param, @Nonnull final Class<T> defaultValue) {
98-
return ((LinkedHashMap<String, T>) getLlmConfigNonNull().getModelParams()).get(param);
207+
@Nullable
208+
private <T> T getLlmConfigParam(@Nonnull final String param) {
209+
if (getLlmConfigNonNull().getModelParams() instanceof LinkedHashMap) {
210+
return ((LinkedHashMap<String, T>) getLlmConfigNonNull().getModelParams()).get(param);
211+
}
212+
return null;
213+
}
214+
215+
@SuppressWarnings("unchecked")
216+
private <T> void setLlmConfigParam(@Nonnull final String param, @Nonnull final T value) {
217+
((LinkedHashMap<String, T>) getLlmConfigNonNull().getModelParams()).put(param, value);
99218
}
100219

101220
@Nonnull

0 commit comments

Comments
 (0)