Skip to content

Commit 0b9a107

Browse files
committed
Work in progress
1 parent 869e83a commit 0b9a107

File tree

4 files changed

+65
-9
lines changed

4 files changed

+65
-9
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
44
import java.util.Map;
55
import javax.annotation.Nonnull;
6+
import javax.annotation.Nullable;
67
import lombok.Getter;
78

89
/** Large language models available in Orchestration. */
910
// https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/models-and-scenarios-in-generative-ai-hub
11+
@Getter
1012
public class OrchestrationAiModel {
11-
@Getter private final LLMModuleConfig config;
13+
private final LLMModuleConfig config;
1214

1315
/** IBM Granite 13B chat completions model */
1416
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
@@ -116,15 +118,64 @@ public class OrchestrationAiModel {
116118
new OrchestrationAiModel("gemini-1.5-flash");
117119

118120
OrchestrationAiModel(@Nonnull final String modelName) {
119-
config = LLMModuleConfig.create().modelName(modelName).modelParams(Map.of());
121+
config = new LLMModuleConfig().modelName(modelName).modelParams(Map.of());
120122
}
121123

122-
private OrchestrationAiModel(
123-
@Nonnull final String modelName, Map<String, ? extends Number> modelParams) {
124-
config = LLMModuleConfig.create().modelName(modelName).modelParams(modelParams);
124+
// private OrchestrationAiModel(
125+
// @Nonnull final String modelName, @Nonnull final Map<String, ? extends Number> modelParams)
126+
// {
127+
// config = new LLMModuleConfig().modelName(modelName).modelParams(modelParams);
128+
// }
129+
130+
private OrchestrationAiModel(@Nonnull final String modelName, @Nonnull final Object modelParams) {
131+
config = new LLMModuleConfig().modelName(modelName).modelParams(modelParams);
132+
}
133+
134+
/**
135+
* Set model version on this model.
136+
*
137+
* <pre>{@code
138+
* .modelVersion("latest)
139+
* }</pre>
140+
*
141+
* @param version The new version.
142+
* @return New instance of this class with new version.
143+
*/
144+
@Nonnull
145+
public OrchestrationAiModel modelVersion(@Nullable final String version) {
146+
// Question: I need a map but only got an object as modelParams. How are modelParams
147+
// structured?
148+
final var model = new OrchestrationAiModel(config.getModelName(), config.getModelParams());
149+
model.config.setModelVersion(version);
150+
// Question: Is modelVersion not lost as soon as we call modelParams?
151+
// Do we need to propagate this in that function as well?
152+
return model;
125153
}
126154

127-
public OrchestrationAiModel modelParams(Map<String, ? extends Number> modelParams) {
128-
return new OrchestrationAiModel(config.getModelName(), modelParams);
155+
/**
156+
* Set model parameters on this model.
157+
*
158+
* <pre>{@code
159+
* .modelParams(
160+
* Map.of(
161+
* "max_tokens", 50,
162+
* "temperature", 0.1,
163+
* "frequency_penalty", 0,
164+
* "presence_penalty", 0));
165+
* }</pre>
166+
*
167+
* @param modelParams Map of parameters.
168+
* @return New instance of this class.
169+
*/
170+
@Nonnull
171+
public OrchestrationAiModel modelParams(@Nonnull final Object modelParams) {
172+
return new OrchestrationAiModel(config.getModelName(), modelParams)
173+
.modelVersion(config.getModelVersion());
129174
}
175+
176+
// @Nonnull
177+
// public OrchestrationAiModel modelParams(
178+
// @Nonnull final Map<String, ? extends Number> modelParams) {
179+
// return new OrchestrationAiModel(config.getModelName(), modelParams);
180+
// }
130181
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class OrchestrationUnitTest {
7070
"temperature", 0.1,
7171
"frequency_penalty", 0,
7272
"presence_penalty", 0))
73+
.modelVersion("latest")
7374
.getConfig();
7475
private final Function<String, InputStream> fileLoader =
7576
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
@@ -146,6 +147,7 @@ void testTemplating() throws IOException {
146147
.isEqualTo("Reply with 'Orchestration Service is working!' in German");
147148
assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user");
148149
var llm = (LLMModuleResultSynchronous) result.getModuleResults().getLlm();
150+
assertThat(llm).isNotNull();
149151
assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj");
150152
assertThat(llm.getObject()).isEqualTo("chat.completion");
151153
assertThat(llm.getCreated()).isEqualTo(1721224505);
@@ -313,6 +315,7 @@ void maskingAnonymization() throws IOException {
313315

314316
assertThat(result).isNotNull();
315317
GenericModuleResult inputMasking = result.getModuleResults().getInputMasking();
318+
assertThat(inputMasking).isNotNull();
316319
assertThat(inputMasking.getMessage()).isEqualTo("Input to LLM is masked successfully.");
317320
assertThat(inputMasking.getData()).isNotNull();
318321
final var choices = ((LLMModuleResultSynchronous) result.getOrchestrationResult()).getChoices();

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
@RequestMapping("/orchestration")
3333
class OrchestrationController {
3434
private final OrchestrationClient client = new OrchestrationClient();
35-
OrchestrationModuleConfig config = new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO.getConfig());
35+
OrchestrationModuleConfig config =
36+
new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO.getConfig());
3637

3738
/**
3839
* Chat request to OpenAI through the Orchestration service with a simple prompt.

sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ void testTemplate() {
6767
var orchestrationResult = ((LLMModuleResultSynchronous) result.getOrchestrationResult());
6868
assertThat(orchestrationResult.getObject()).isEqualTo("chat.completion");
6969
assertThat(orchestrationResult.getCreated()).isGreaterThan(1);
70-
assertThat(result.getOrchestrationResult().getModel()).isEqualTo(model);
70+
assertThat(((LLMModuleResultSynchronous) result.getOrchestrationResult()).getModel())
71+
.isEqualTo(model);
7172
choices = ((LLMModuleResultSynchronous) orchestrationResult).getChoices();
7273
assertThat(choices.get(0).getIndex()).isZero();
7374
assertThat(choices.get(0).getMessage().getContent()).isNotEmpty();

0 commit comments

Comments
 (0)