Skip to content

Commit 46f4ba4

Browse files
WiP
1 parent fcab936 commit 46f4ba4

File tree

16 files changed

+860
-0
lines changed

16 files changed

+860
-0
lines changed

orchestration/pom.xml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,24 @@
123123
<artifactId>junit-jupiter-params</artifactId>
124124
<scope>test</scope>
125125
</dependency>
126+
<dependency>
127+
<groupId>org.springframework.boot</groupId>
128+
<artifactId>spring-boot-autoconfigure</artifactId>
129+
<version>3.4.1</version>
130+
<scope>compile</scope>
131+
</dependency>
132+
<dependency>
133+
<groupId>org.springframework.ai</groupId>
134+
<artifactId>spring-ai-core</artifactId>
135+
<version>1.0.0-SNAPSHOT</version>
136+
<scope>compile</scope>
137+
</dependency>
138+
<dependency>
139+
<groupId>org.springframework</groupId>
140+
<artifactId>spring-context</artifactId>
141+
<version>6.2.1</version>
142+
<scope>compile</scope>
143+
</dependency>
126144
</dependencies>
127145

128146
<profiles>
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import com.sap.ai.sdk.orchestration.AssistantMessage;
4+
import com.sap.ai.sdk.orchestration.OrchestrationClient;
5+
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
6+
import com.sap.ai.sdk.orchestration.OrchestrationPrompt;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
import java.util.function.Function;
11+
import javax.annotation.Nonnull;
12+
13+
import com.sap.ai.sdk.orchestration.SystemMessage;
14+
import com.sap.ai.sdk.orchestration.UserMessage;
15+
import lombok.RequiredArgsConstructor;
16+
import lombok.extern.slf4j.Slf4j;
17+
import lombok.val;
18+
import org.springframework.ai.chat.messages.Message;
19+
import org.springframework.ai.chat.model.ChatModel;
20+
import org.springframework.ai.chat.model.ChatResponse;
21+
import org.springframework.ai.chat.prompt.Prompt;
22+
23+
/** Spring AI integration for the orchestration service. */
24+
@Slf4j
25+
@RequiredArgsConstructor
26+
public class OrchestrationChatModel implements ChatModel {
27+
@Nonnull private final OrchestrationClient client = new OrchestrationClient();
28+
@Nonnull private final OrchestrationModuleConfig config;
29+
30+
@Override
31+
public ChatResponse call(Prompt prompt) {
32+
val orchestrationPrompt = toOrchestrationPrompt(prompt);
33+
val response = client.chatCompletion(orchestrationPrompt, config);
34+
return OrchestrationChatResponse.fromOrchestrationResponse(response.getOriginalResponse());
35+
}
36+
37+
@Nonnull
38+
private OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt) {
39+
val messages = toOrchestrationMessages(prompt.getInstructions());
40+
return new OrchestrationPrompt(Map.of(), messages);
41+
}
42+
43+
// endregion
44+
45+
@Nonnull
46+
private static com.sap.ai.sdk.orchestration.Message[] toOrchestrationMessages(
47+
@Nonnull final List<Message> messages) {
48+
final Function<Message, com.sap.ai.sdk.orchestration.Message> mapper =
49+
msg ->
50+
switch (msg.getMessageType()) {
51+
case SYSTEM:
52+
yield new SystemMessage(msg.getText());
53+
case USER:
54+
yield new UserMessage(msg.getText());
55+
case ASSISTANT:
56+
yield new AssistantMessage(msg.getText());
57+
case TOOL:
58+
throw new IllegalArgumentException("Tool messages are not supported");
59+
};
60+
return messages.stream().map(mapper).toArray(com.sap.ai.sdk.orchestration.Message[]::new);
61+
}
62+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
4+
import com.sap.ai.sdk.orchestration.model.LLMModuleConfig;
5+
6+
import java.util.LinkedHashMap;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.Objects;
10+
import javax.annotation.Nonnull;
11+
import javax.annotation.Nullable;
12+
import lombok.AccessLevel;
13+
import lombok.Data;
14+
import lombok.Getter;
15+
import lombok.Setter;
16+
import org.springframework.ai.chat.prompt.ChatOptions;
17+
18+
/** Configuration to be used for orchestration requests. */
19+
@Data
20+
@Getter(AccessLevel.NONE)
21+
@Setter(AccessLevel.NONE)
22+
public class OrchestrationChatOptions implements ChatOptions {
23+
24+
@Getter(AccessLevel.PUBLIC)
25+
@Nonnull
26+
private Map<String, String> templateParameters = Map.of();
27+
28+
@Getter(AccessLevel.PUBLIC)
29+
@Setter(AccessLevel.PUBLIC)
30+
@Nonnull
31+
OrchestrationModuleConfig config = new OrchestrationModuleConfig();
32+
33+
// region satisfy the ChatOptions interface, delegating to the LLM config
34+
@Nullable
35+
@Override
36+
public String getModel() {
37+
return getLlmConfigNonNull().getModelName();
38+
}
39+
40+
@Nullable
41+
String getModelVersion() {
42+
return getLlmConfigNonNull().getModelVersion();
43+
}
44+
45+
@Nullable
46+
@Override
47+
public Double getFrequencyPenalty() {
48+
return getLlmConfigParam("frequencyPenalty", Double.class);
49+
}
50+
51+
@Nullable
52+
@Override
53+
public Integer getMaxTokens() {
54+
return getLlmConfigParam("maxTokens", Integer.class);
55+
}
56+
57+
@Nullable
58+
@Override
59+
public Double getPresencePenalty() {
60+
return getLlmConfigParam("presencePenalty", Double.class);
61+
}
62+
63+
@SuppressWarnings("unchecked")
64+
@Nullable
65+
@Override
66+
public List<String> getStopSequences() {
67+
return getLlmConfigParam("stopSequences", List.class);
68+
}
69+
70+
@Nullable
71+
@Override
72+
public Double getTemperature() {
73+
return getLlmConfigParam("temperature", Double.class);
74+
}
75+
76+
@Nullable
77+
@Override
78+
public Integer getTopK() {
79+
return getLlmConfigParam("topK", Integer.class);
80+
}
81+
82+
@Nullable
83+
@Override
84+
public Double getTopP() {
85+
return getLlmConfigParam("topP", Double.class);
86+
}
87+
88+
@Override
89+
public OrchestrationChatOptions copy() {
90+
var copy = new OrchestrationChatOptions();
91+
copy.config = this.config;
92+
copy.templateParameters.putAll(this.templateParameters);
93+
return copy;
94+
}
95+
96+
@SuppressWarnings("unchecked")
97+
@Nonnull
98+
private <T> T getLlmConfigParam(@Nonnull final String param, @Nonnull final Class<T> defaultValue) {
99+
return ((LinkedHashMap<String, T>) getLlmConfigNonNull().getModelParams()).get(param);
100+
}
101+
102+
@Nonnull
103+
private LLMModuleConfig getLlmConfigNonNull() {
104+
return Objects.requireNonNull(config.getLlmConfig(), "LLM config is not set");
105+
}
106+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import com.sap.ai.sdk.orchestration.model.CompletionPostResponse;
4+
import com.sap.ai.sdk.orchestration.model.LLMChoice;
5+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
6+
import com.sap.ai.sdk.orchestration.model.ModuleResults;
7+
import com.sap.ai.sdk.orchestration.model.TokenUsage;
8+
import java.util.HashMap;
9+
import java.util.List;
10+
import java.util.Map;
11+
import javax.annotation.Nonnull;
12+
import lombok.EqualsAndHashCode;
13+
import lombok.Value;
14+
import lombok.val;
15+
import org.springframework.ai.chat.messages.AssistantMessage;
16+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
17+
import org.springframework.ai.chat.metadata.DefaultUsage;
18+
import org.springframework.ai.chat.model.ChatResponse;
19+
import org.springframework.ai.chat.model.Generation;
20+
21+
@Value
22+
@EqualsAndHashCode(callSuper = true)
23+
public class OrchestrationChatResponse extends ChatResponse {
24+
25+
private OrchestrationChatResponse(
26+
@Nonnull final List<Generation> generations,
27+
@Nonnull final ChatResponseMetadata metadata) {
28+
super(generations, metadata);
29+
}
30+
31+
@Nonnull
32+
public static OrchestrationChatResponse fromOrchestrationResponse(
33+
@Nonnull final CompletionPostResponse response) {
34+
val result = (LLMModuleResultSynchronous) response.getOrchestrationResult();
35+
val generations = toGenerations(result);
36+
val metadata = toChatResponseMetadata(result);
37+
return new OrchestrationChatResponse(generations, metadata);
38+
}
39+
40+
@Nonnull
41+
static List<Generation> toGenerations(@Nonnull final LLMModuleResultSynchronous result) {
42+
return result.getChoices().stream()
43+
.map(OrchestrationChatResponse::toAssistantMessage)
44+
.map(Generation::new)
45+
.toList();
46+
}
47+
48+
@Nonnull
49+
static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choice) {
50+
Map<String, Object> metadata = new HashMap<>();
51+
metadata.put("finish_reason", choice.getFinishReason());
52+
metadata.put("index", choice.getIndex());
53+
if (!choice.getLogprobs().isEmpty()) {
54+
metadata.put("logprobs", choice.getLogprobs());
55+
}
56+
return new AssistantMessage(choice.getMessage().getContent(), metadata);
57+
}
58+
59+
@Nonnull
60+
static ChatResponseMetadata toChatResponseMetadata(
61+
@Nonnull final LLMModuleResultSynchronous orchestrationResult) {
62+
var metadataBuilder = ChatResponseMetadata.builder();
63+
64+
metadataBuilder
65+
.id(orchestrationResult.getId())
66+
.model(orchestrationResult.getModel())
67+
.keyValue("object", orchestrationResult.getObject())
68+
.keyValue("created", orchestrationResult.getCreated())
69+
.usage(toDefaultUsage(orchestrationResult.getUsage()));
70+
71+
return metadataBuilder.build();
72+
}
73+
74+
@Nonnull
75+
private static DefaultUsage toDefaultUsage(@Nonnull final TokenUsage usage) {
76+
return new DefaultUsage(
77+
usage.getPromptTokens().longValue(),
78+
usage.getCompletionTokens().longValue(),
79+
usage.getTotalTokens().longValue());
80+
}
81+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GEMINI_1_5_FLASH;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE;
5+
import static org.assertj.core.api.Assertions.assertThat;
6+
7+
import org.junit.jupiter.api.BeforeEach;
8+
import org.junit.jupiter.api.Test;
9+
10+
class OrchestrationChatOptionsTest {
11+
12+
OrchestrationChatOptions opts;
13+
14+
@BeforeEach
15+
void setUp() {
16+
opts = new OrchestrationChatOptions();
17+
}
18+
19+
@Test
20+
void testHyperParameters() {
21+
var llm = GEMINI_1_5_FLASH.withParam(TEMPERATURE, 0.5).withParam("maxTokens", 100);
22+
var cac = opts.withLlmConfig(llm);
23+
opts.config = cac;
24+
25+
assertThat(opts.getTemperature()).isEqualTo(0.5);
26+
assertThat(opts.getMaxTokens()).isEqualTo(100);
27+
}
28+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.sap.ai.sdk.orchestration.spring;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import com.sap.ai.sdk.orchestration.model.ChatMessage;
6+
import com.sap.ai.sdk.orchestration.model.LLMChoice;
7+
import com.sap.ai.sdk.orchestration.model.LLMModuleResult;
8+
import com.sap.ai.sdk.orchestration.model.LLMModuleResultSynchronous;
9+
import com.sap.ai.sdk.orchestration.model.TokenUsage;
10+
import java.util.List;
11+
import org.junit.jupiter.api.Test;
12+
import org.springframework.ai.chat.messages.AssistantMessage;
13+
14+
class OrchestrationChatResponseTest {
15+
16+
@Test
17+
void testToAssistantMessage() {
18+
var choice =
19+
LLMChoice.create()
20+
.index(0)
21+
.message(ChatMessage.create().role("assistant").content("Hello, world!"))
22+
.finishReason("stop");
23+
24+
AssistantMessage message = OrchestrationChatResponse.toAssistantMessage(choice);
25+
26+
assertThat(message.getContent()).isEqualTo("Hello, world!");
27+
assertThat(message.getMetadata()).containsEntry("finish_reason", "stop");
28+
assertThat(message.getMetadata()).containsEntry("index", 0);
29+
}
30+
31+
@Test
32+
void testToChatResponseMetadata() {
33+
var moduleResult =
34+
LLMModuleResultSynchronous.create()
35+
.id("test-id")
36+
._object("test-object")
37+
.created(123456789)
38+
.model("test-model")
39+
.choices(List.of())
40+
.usage(TokenUsage.create().completionTokens(20).promptTokens(10).totalTokens(30));
41+
42+
var metadata = OrchestrationChatResponse.toChatResponseMetadata(moduleResult);
43+
44+
assertThat(metadata.getId()).isEqualTo("test-id");
45+
assertThat(metadata.getModel()).isEqualTo("test-model");
46+
assertThat(metadata.<String>get("object")).isEqualTo("test-object");
47+
assertThat(metadata.<Integer>get("created")).isEqualTo(123456789);
48+
49+
var usage = metadata.getUsage();
50+
51+
assertThat(usage.getPromptTokens()).isEqualTo(10L);
52+
assertThat(usage.getGenerationTokens()).isEqualTo(20L);
53+
assertThat(usage.getTotalTokens()).isEqualTo(30L);
54+
}
55+
}

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@
165165
<artifactId>spring-app</artifactId>
166166
<version>${project.version}</version>
167167
</dependency>
168+
<dependency>
169+
<groupId>com.sap.ai.sdk.app</groupId>
170+
<artifactId>spring-ai-app</artifactId>
171+
<version>${project.version}</version>
172+
</dependency>
168173
</dependencies>
169174
</dependencyManagement>
170175
<dependencies>
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Sample Code and E2E Test
2+
3+
![e2e-test](https://github.com/SAP/ai-sdk-java/actions/workflows/e2e-test.yaml/badge.svg)
4+
5+
Sample code to demonstrate the usage of the SAP AI SDK.
6+
Also used as basis for running E2E tests.
7+
8+
## Build, Run, Deploy Locally
9+
10+
Build the project with:
11+
12+
1. `mvn compile`
13+
2. [Download a service key for your AI Core service instance](../../README.md#set-credentials-as-dedicated-environment-variable)
14+
3. Create the environment variable `AICORE_SERVICE_KEY`
15+
4. Run the application with `mvn spring-boot:run`
16+
5. [See all available endpoints](localhost:8080)
17+
18+
## Run the E2E Test
19+
20+
Trigger the [GitHub Action](https://github.com/SAP/ai-sdk-java/actions/workflows/e2e-test.yml).

0 commit comments

Comments
 (0)