Skip to content

Commit d2c65bb

Browse files
CharlesDuboisSAPbot-sdk-jsJonas-IsrMatKuhr
authored
Added Orchestration LLM Config Convenience (#152)
* Added Orchestration LLM Config Convenience * Fixed orchestrationModelAvailability * Fixed orchestrationModelAvailability * Revert "Fixed orchestrationModelAvailability" This reverts commit ff88704. * Re-added all models, removed test * Formatting * Work in progress * Work in progress * Work in progress * Work in Progress. * Merged main branch. * Matthias' review * renaming * Formatting * change * 2nd review * Update orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java * Formatting * Fix * Fix --------- Co-authored-by: SAP Cloud SDK Bot <[email protected]> Co-authored-by: Jonas Israel <[email protected]> Co-authored-by: Matthias Kuhr <[email protected]> Co-authored-by: Matthias Kuhr <[email protected]>
1 parent cf661ad commit d2c65bb

File tree

8 files changed

+167
-32
lines changed

8 files changed

+167
-32
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ To use the Orchestration service, create a client and a configuration object:
7777
var client = new OrchestrationClient();
7878

7979
var config = new OrchestrationModuleConfig()
80-
.withLlmConfig(LLMModuleConfig.create().modelName("gpt-35-turbo").modelParams(Map.of()));
80+
.withLlmConfig(OrchestrationAiModel.GPT_4O);
8181
```
8282

8383
Please also refer to [our sample code](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java) for this and all following code examples.
@@ -214,16 +214,16 @@ In this example, the input will be masked before the call to the LLM. Note that
214214

215215
### Set model parameters
216216

217-
Change your LLM module configuration to add model parameters:
217+
Change your LLM configuration to add model parameters:
218218

219219
```java
220-
var llmConfig =
221-
LLMModuleConfig.create()
222-
.modelName("gpt-35-turbo")
223-
.modelParams(
220+
OrchestrationAiModel customGPT4O =
221+
OrchestrationAiModel.GPT_4O
222+
.withModelParams(
224223
Map.of(
225224
"max_tokens", 50,
226225
"temperature", 0.1,
227226
"frequency_penalty", 0,
228-
"presence_penalty", 0));
227+
"presence_penalty", 0))
228+
.withModelVersion("2024-05-13");
229229
```
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.AllArgsConstructor;
import lombok.Value;
import lombok.With;

/**
 * Large language models available in Orchestration.
 *
 * <p>Instances are immutable: the Lombok-generated {@code withModelParams(...)} and {@code
 * withModelVersion(...)} methods each return a new copy, so the predefined constants below are
 * never modified by customization.
 */
@Value
@With
@AllArgsConstructor
public class OrchestrationAiModel {
  /** The name of the model, as expected by the Orchestration service. */
  String modelName;

  /**
   * Optional parameters on this model.
   *
   * <pre>{@code
   * Map.of(
   *     "max_tokens", 50,
   *     "temperature", 0.1,
   *     "frequency_penalty", 0,
   *     "presence_penalty", 0)
   * }</pre>
   */
  Map<String, Object> modelParams;

  /** The version of the model, defaults to "latest". */
  String modelVersion;

  /** IBM Granite 13B chat completions model */
  public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
      new OrchestrationAiModel("ibm--granite-13b-chat");

  /** MistralAI Mistral Large Instruct model */
  public static final OrchestrationAiModel MISTRAL_LARGE_INSTRUCT =
      new OrchestrationAiModel("mistralai--mistral-large-instruct");

  /** MistralAI Mixtral 8x7B Instruct v01 model */
  public static final OrchestrationAiModel MIXTRAL_8X7B_INSTRUCT_V01 =
      new OrchestrationAiModel("mistralai--mixtral-8x7b-instruct-v01");

  /** Meta Llama3 70B Instruct model */
  public static final OrchestrationAiModel LLAMA3_70B_INSTRUCT =
      new OrchestrationAiModel("meta--llama3-70b-instruct");

  /** Meta Llama3.1 70B Instruct model */
  public static final OrchestrationAiModel LLAMA3_1_70B_INSTRUCT =
      new OrchestrationAiModel("meta--llama3.1-70b-instruct");

  /** Anthropic Claude 3 Sonnet model */
  public static final OrchestrationAiModel CLAUDE_3_SONNET =
      new OrchestrationAiModel("anthropic--claude-3-sonnet");

  /** Anthropic Claude 3 Haiku model */
  public static final OrchestrationAiModel CLAUDE_3_HAIKU =
      new OrchestrationAiModel("anthropic--claude-3-haiku");

  /** Anthropic Claude 3 Opus model */
  public static final OrchestrationAiModel CLAUDE_3_OPUS =
      new OrchestrationAiModel("anthropic--claude-3-opus");

  /** Anthropic Claude 3.5 Sonnet model */
  public static final OrchestrationAiModel CLAUDE_3_5_SONNET =
      new OrchestrationAiModel("anthropic--claude-3.5-sonnet");

  /** Amazon Titan Text Lite model */
  public static final OrchestrationAiModel TITAN_TEXT_LITE =
      new OrchestrationAiModel("amazon--titan-text-lite");

  /** Amazon Titan Text Express model */
  public static final OrchestrationAiModel TITAN_TEXT_EXPRESS =
      new OrchestrationAiModel("amazon--titan-text-express");

  /** Azure OpenAI GPT-3.5 Turbo chat completions model */
  public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");

  /** Azure OpenAI GPT-3.5 Turbo chat completions model */
  public static final OrchestrationAiModel GPT_35_TURBO_16K =
      new OrchestrationAiModel("gpt-35-turbo-16k");

  /** Azure OpenAI GPT-4 chat completions model */
  public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");

  /** Azure OpenAI GPT-4-32k chat completions model */
  public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");

  /** Azure OpenAI GPT-4o chat completions model */
  public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");

  /** Azure OpenAI GPT-4o-mini chat completions model */
  public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");

  /** Google Cloud Platform Gemini 1.0 Pro model */
  public static final OrchestrationAiModel GEMINI_1_0_PRO =
      new OrchestrationAiModel("gemini-1.0-pro");

  /** Google Cloud Platform Gemini 1.5 Pro model */
  public static final OrchestrationAiModel GEMINI_1_5_PRO =
      new OrchestrationAiModel("gemini-1.5-pro");

  /** Google Cloud Platform Gemini 1.5 Flash model */
  public static final OrchestrationAiModel GEMINI_1_5_FLASH =
      new OrchestrationAiModel("gemini-1.5-flash");

  /**
   * Convenience constructor used for the predefined constants above: no model parameters and the
   * default version "latest". Package-private on purpose — callers customize via the generated
   * {@code with...} copy methods instead.
   */
  OrchestrationAiModel(@Nonnull final String modelName) {
    this(modelName, Map.of(), "latest");
  }

  /**
   * Translates this model into the generated-client configuration object.
   *
   * @return an {@link LLMModuleConfig} carrying this model's name, parameters, and version.
   */
  @Nonnull
  LLMModuleConfig createConfig() {
    return new LLMModuleConfig()
        .modelName(modelName)
        .modelParams(modelParams)
        .modelVersion(modelVersion);
  }
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
55
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
66
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
7+
import javax.annotation.Nonnull;
78
import javax.annotation.Nullable;
89
import lombok.AccessLevel;
910
import lombok.AllArgsConstructor;
1011
import lombok.NoArgsConstructor;
1112
import lombok.Value;
1213
import lombok.With;
14+
import lombok.experimental.Tolerate;
1315

1416
/**
1517
* Represents the configuration for the orchestration service. Allows for configuring the different
@@ -48,4 +50,16 @@ public class OrchestrationModuleConfig {
4850

4951
/** A content filter to filter the prompt. */
5052
@Nullable FilteringModuleConfig filteringConfig;
53+
54+
/**
55+
* Creates a new configuration with the given LLM configuration.
56+
*
57+
* @param aiModel The LLM configuration to use.
58+
* @return A new configuration with the given LLM configuration.
59+
*/
60+
@Tolerate
61+
@Nonnull
62+
public OrchestrationModuleConfig withLlmConfig(@Nonnull final OrchestrationAiModel aiModel) {
63+
return withLlmConfig(aiModel.createConfig());
64+
}
5165
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/ConfigToRequestTransformerTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.sap.ai.sdk.orchestration;
22

3-
import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.LLM_CONFIG;
3+
import static com.sap.ai.sdk.orchestration.OrchestrationUnitTest.CUSTOM_GPT_35;
44
import static org.assertj.core.api.Assertions.assertThat;
55
import static org.assertj.core.api.Assertions.assertThatThrownBy;
66

@@ -71,7 +71,7 @@ void testMessagesHistory() {
7171
var prompt = new OrchestrationPrompt("bar").messageHistory(List.of(systemMessage));
7272
var actual =
7373
ConfigToRequestTransformer.toCompletionPostRequest(
74-
prompt, new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG));
74+
prompt, new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35));
7575

7676
assertThat(actual.getMessagesHistory()).containsExactly(systemMessage);
7777
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
1717
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
1818
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
19+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
1920
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0;
2021
import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4;
2122
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
@@ -39,7 +40,6 @@
3940
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
4041
import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult;
4142
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
42-
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
4343
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
4444
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
4545
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
@@ -62,15 +62,13 @@
6262
*/
6363
@WireMockTest
6464
class OrchestrationUnitTest {
65-
static final LLMModuleConfig LLM_CONFIG =
66-
new LLMModuleConfig()
67-
.modelName("gpt-35-turbo-16k")
68-
.modelParams(
69-
Map.of(
70-
"max_tokens", 50,
71-
"temperature", 0.1,
72-
"frequency_penalty", 0,
73-
"presence_penalty", 0));
65+
static final OrchestrationAiModel CUSTOM_GPT_35 =
66+
GPT_35_TURBO_16K.withModelParams(
67+
Map.of(
68+
"max_tokens", 50,
69+
"temperature", 0.1,
70+
"frequency_penalty", 0,
71+
"presence_penalty", 0));
7472
private final Function<String, InputStream> fileLoader =
7573
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
7674

@@ -106,7 +104,7 @@ void setup(WireMockRuntimeInfo server) {
106104
.forDeploymentByScenario("orchestration")
107105
.withResourceGroup("my-resource-group");
108106
client = new OrchestrationClient(deployment);
109-
config = new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
107+
config = new OrchestrationModuleConfig().withLlmConfig(CUSTOM_GPT_35);
110108
prompt = new OrchestrationPrompt("Hello World! Why is this phrase so famous?");
111109
}
112110

@@ -146,6 +144,7 @@ void testTemplating() throws IOException {
146144
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
147145
assertThat(response.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
148146
var llm = (LLMModuleResultSynchronous) response.getModuleResults().getLlm();
147+
assertThat(llm).isNotNull();
149148
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
150149
assertThat(llm.getObject()).isEqualTo("chat.completion");
151150
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -315,6 +314,7 @@ void maskingPseudonymization() throws IOException {
315314

316315
assertThat(response).isNotNull();
317316
GenericModuleResult inputMasking = response.getModuleResults().getInputMasking();
317+
assertThat(inputMasking).isNotNull();
318318
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
319319
assertThat(inputMasking.getData()).isNotNull();
320320
assertThat(result.getContent()).contains("Hi Mallory");

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.sap.ai.sdk.app.controllers;
22

3+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;
4+
35
import com.sap.ai.sdk.orchestration.OrchestrationChatResponse;
46
import com.sap.ai.sdk.orchestration.OrchestrationClient;
57
import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig;
@@ -13,7 +15,6 @@
1315
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
1416
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
1517
import com.sap.ai.sdk.orchestration.client.model.InputFilteringConfig;
16-
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
1718
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
1819
import com.sap.ai.sdk.orchestration.client.model.OutputFilteringConfig;
1920
import com.sap.ai.sdk.orchestration.client.model.Template;
@@ -30,12 +31,8 @@
3031
@RestController
3132
@RequestMapping("/orchestration")
3233
class OrchestrationController {
33-
static final LLMModuleConfig LLM_CONFIG =
34-
new LLMModuleConfig().modelName("gpt-35-turbo").modelParams(Map.of());
35-
3634
private final OrchestrationClient client = new OrchestrationClient();
37-
private final OrchestrationModuleConfig config =
38-
new OrchestrationModuleConfig().withLlmConfig(LLM_CONFIG);
35+
OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO);
3936

4037
/**
4138
* Chat request to OpenAI through the Orchestration service with a simple prompt.
@@ -170,7 +167,7 @@ public OrchestrationChatResponse maskingAnonymization() {
170167

171168
/**
172169
* Let the orchestration service a response to a hypothetical user who provided feedback on the AI
173-
* SDK. Pseydonymize the user's name and location to protect their privacy.
170+
* SDK. Pseudonymize the user's name and location to protect their privacy.
174171
*
175172
* @return the result object
176173
*/

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
import com.sap.ai.sdk.orchestration.client.model.LLMModuleResultSynchronous;
1111
import java.util.List;
1212
import java.util.Map;
13+
import lombok.extern.slf4j.Slf4j;
1314
import org.assertj.core.api.InstanceOfAssertFactories;
1415
import org.junit.jupiter.api.BeforeEach;
1516
import org.junit.jupiter.api.Test;
1617

18+
@Slf4j
1719
class OrchestrationTest {
1820
OrchestrationController controller;
1921

@@ -32,6 +34,9 @@ void testCompletion() {
3234

3335
@Test
3436
void testTemplate() {
37+
assertThat(controller.config.getLlmConfig()).isNotNull();
38+
final var modelName = controller.config.getLlmConfig().getModelName();
39+
3540
final var response = controller.template();
3641
final var result = response.getOriginalResponse();
3742

@@ -43,7 +48,7 @@ void testTemplate() {
4348
assertThat(llm.getId()).isNotEmpty();
4449
assertThat(llm.getObject()).isEqualTo("chat.completion");
4550
assertThat(llm.getCreated()).isGreaterThan(1);
46-
assertThat(llm.getModel()).isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
51+
assertThat(llm.getModel()).isEqualTo(modelName);
4752
var choices = llm.getChoices();
4853
assertThat(choices.get(0).getIndex()).isZero();
4954
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();
@@ -57,8 +62,7 @@ void testTemplate() {
5762
var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult());
5863
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
5964
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
60-
assertThat(orchestrationResult.getModel())
61-
.isEqualTo(OrchestrationController.LLM_CONFIG.getModelName());
65+
assertThat(orchestrationResult.getModel()).isEqualTo(modelName);
6266
choices = orchestrationResult.getChoices();
6367
assertThat(choices.get(0).getIndex()).isZero();
6468
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/ScenarioTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
import org.junit.jupiter.api.DisplayName;
1212
import org.junit.jupiter.api.Test;
1313

14-
public class ScenarioTest {
14+
class ScenarioTest {
1515

1616
@Test
1717
@DisplayName("Declared OpenAI models must match AI Core's available OpenAI models")
1818
@SneakyThrows
19-
public void openAiModelAvailability() {
19+
void openAiModelAvailability() {
2020

2121
// Gather AI Core's list of available OpenAI models
2222
final var aiModelList = new ScenarioController().getModels().getResources();

0 commit comments

Comments
 (0)