Skip to content

Commit e352343

Browse files
newtorka-djjtang1985bot-sdk-jsMatKuhr
authored
[Orchestration Convenience] Typed model parameters (#180)
* Initial * Docs * Adjust solution * Update ORCHESTRATION_CHAT_COMPLETION.md * Adjust solution to (value) type safety. * Remove unpopular parameters * Formatting * Add doc link to JavaDoc * Update sample code --------- Co-authored-by: Alexander Dümont <[email protected]> Co-authored-by: Junjie Tang <[email protected]> Co-authored-by: SAP Cloud SDK Bot <[email protected]> Co-authored-by: Matthias Kuhr <[email protected]>
1 parent d4abd9d commit e352343

File tree

5 files changed

+127
-15
lines changed

5 files changed

+127
-15
lines changed

docs/guides/ORCHESTRATION_CHAT_COMPLETION.md

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,10 @@ Change your LLM configuration to add model parameters:
200200
```java
201201
OrchestrationAiModel customGPT4O =
202202
OrchestrationAiModel.GPT_4O
203-
.withParams(
204-
Map.of(
205-
"max_tokens", 50,
206-
"temperature", 0.1,
207-
"frequency_penalty", 0,
208-
"presence_penalty", 0))
203+
.withParam(MAX_TOKENS, 50)
204+
.withParam(TEMPERATURE, 0.1)
205+
.withParam(FREQUENCY_PENALTY, 0)
206+
.withParam(PRESENCE_PENALTY, 0)
209207
.withVersion("2024-05-13");
210208
```
211209

@@ -225,4 +223,4 @@ var prompt = new OrchestrationPrompt(Map.of("your-input-parameter", "your-param-
225223
new OrchestrationClient().executeRequestFromJsonModuleConfig(prompt, configJson);
226224
```
227225

228-
While this is not recommended for long term use, it can be useful for creating demos and PoCs.
226+
While this is not recommended for long term use, it can be useful for creating demos and PoCs.

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.sap.ai.sdk.orchestration.model.LLMModuleConfig;
4+
import java.util.LinkedHashMap;
45
import java.util.Map;
56
import javax.annotation.Nonnull;
7+
import javax.annotation.Nullable;
68
import lombok.AllArgsConstructor;
79
import lombok.Value;
810
import lombok.With;
@@ -25,6 +27,10 @@ public class OrchestrationAiModel {
2527
* "frequency_penalty", 0,
2628
* "presence_penalty", 0)
2729
* }</pre>
30+
*
31+
* @link <a
32+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/harmonized-api">SAP
33+
* AI Core: Orchestration - Harmonized API</a>
2834
*/
2935
Map<String, Object> params;
3036

@@ -114,4 +120,72 @@ public class OrchestrationAiModel {
114120
LLMModuleConfig createConfig() {
115121
return LLMModuleConfig.create().modelName(name).modelParams(params).modelVersion(version);
116122
}
123+
124+
/**
125+
* Additional parameter on this model.
126+
*
127+
* @param key the parameter key.
128+
* @param value the parameter value, nullable.
129+
* @return A new model with the additional parameter.
130+
* @link <a
131+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/harmonized-api">SAP
132+
* AI Core: Orchestration - Harmonized API</a>
133+
*/
134+
@Nonnull
135+
public OrchestrationAiModel withParam(@Nonnull final String key, @Nullable final Object value) {
136+
final var params = new LinkedHashMap<>(getParams());
137+
params.put(key, value);
138+
return withParams(params);
139+
}
140+
141+
/**
142+
* Additional parameter on this model.
143+
*
144+
* @param param the parameter key.
145+
* @param value the parameter value, nullable.
146+
* @param <ValueT> the parameter value type.
147+
* @return A new model with the additional parameter.
148+
* @link <a
149+
* href="https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/harmonized-api">SAP
150+
* AI Core: Orchestration - Harmonized API</a>
151+
*/
152+
@Nonnull
153+
public <ValueT> OrchestrationAiModel withParam(
154+
@Nonnull final Parameter<ValueT> param, @Nullable final ValueT value) {
155+
return withParam(param.getName(), value);
156+
}
157+
158+
/**
159+
* Parameter key for a model.
160+
*
161+
* @param <ValueT> the parameter value type.
162+
*/
163+
@FunctionalInterface
164+
public interface Parameter<ValueT> {
165+
/** The maximum number of tokens to generate. */
166+
Parameter<Integer> MAX_TOKENS = () -> "max_tokens";
167+
168+
/** The sampling temperature. */
169+
Parameter<Number> TEMPERATURE = () -> "temperature";
170+
171+
/** The frequency penalty. */
172+
Parameter<Number> FREQUENCY_PENALTY = () -> "frequency_penalty";
173+
174+
/** The presence penalty. */
175+
Parameter<Number> PRESENCE_PENALTY = () -> "presence_penalty";
176+
177+
/** The probability mass to be considered . */
178+
Parameter<Number> TOP_P = () -> "top_p";
179+
180+
/** The number of chat completion choices to generate for each input message. */
181+
Parameter<Integer> N = () -> "n";
182+
183+
/**
184+
* The name of the parameter.
185+
*
186+
* @return the name of the parameter.
187+
*/
188+
@Nonnull
189+
String getName();
190+
}
117191
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
44
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O;
5+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.MAX_TOKENS;
56
import static org.assertj.core.api.Assertions.assertThat;
67
import static org.assertj.core.api.Assertions.assertThatThrownBy;
78

@@ -65,6 +66,44 @@ void testDpiMaskingConfig() {
6566
.hasSize(1);
6667
}
6768

69+
@Test
70+
void testParams() {
71+
// test withParams(Map<String, Object>)
72+
{
73+
var params = Map.<String, Object>of("foo", "bar", "fizz", "buzz");
74+
75+
var modelA = GPT_4O.withParams(params);
76+
var modelB = modelA.withParams(params);
77+
assertThat(modelA).isEqualTo(modelB);
78+
79+
var modelC = modelA.withParams(Map.of("foo", "bar"));
80+
assertThat(modelA).isNotEqualTo(modelC);
81+
82+
var modelD = modelA.withParams(Map.of("foo", "bazz"));
83+
assertThat(modelA).isNotEqualTo(modelD);
84+
}
85+
86+
// test withParam(String, Object)
87+
{
88+
var modelA = GPT_4O.withParam("foo", "bar");
89+
var modelB = modelA.withParam("foo", "bar");
90+
assertThat(modelA).isEqualTo(modelB);
91+
92+
var modelC = modelA.withParam("foo", "bazz");
93+
assertThat(modelA).isNotEqualTo(modelC);
94+
}
95+
96+
// test withParam(Parameter, Object)
97+
{
98+
var modelA = GPT_4O.withParam(MAX_TOKENS, 10);
99+
var modelB = modelA.withParam(MAX_TOKENS, 10);
100+
assertThat(modelA).isEqualTo(modelB);
101+
102+
var modelC = modelA.withParam(MAX_TOKENS, 20);
103+
assertThat(modelA).isNotEqualTo(modelC);
104+
}
105+
}
106+
68107
@Test
69108
void testLLMConfig() {
70109
Map<String, Object> params = Map.of("foo", "bar");

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE;
2020
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
2121
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
22+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.*;
2223
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
2324
import static org.assertj.core.api.Assertions.assertThat;
2425
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -52,12 +53,12 @@
5253
@WireMockTest
5354
class OrchestrationUnitTest {
5455
static final OrchestrationAiModel CUSTOM_GPT_35 =
55-
GPT_35_TURBO_16K.withParams(
56-
Map.of(
57-
"max_tokens", 50,
58-
"temperature", 0.1,
59-
"frequency_penalty", 0,
60-
"presence_penalty", 0));
56+
GPT_35_TURBO_16K
57+
.withParam(MAX_TOKENS, 50)
58+
.withParam(TEMPERATURE, 0.1)
59+
.withParam(FREQUENCY_PENALTY, 0)
60+
.withParam(PRESENCE_PENALTY, 0);
61+
6162
private final Function<String, InputStream> fileLoader =
6263
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
6364

sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.sap.ai.sdk.app.controllers;
22

33
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO;
4+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE;
45

56
import com.sap.ai.sdk.orchestration.AzureContentFilter;
67
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
@@ -26,8 +27,7 @@
2627
class OrchestrationController {
2728
private final OrchestrationClient client = new OrchestrationClient();
2829
OrchestrationModuleConfig config =
29-
new OrchestrationModuleConfig()
30-
.withLlmConfig(GPT_35_TURBO.withParams(Map.of("temperature", 0.0)));
30+
new OrchestrationModuleConfig().withLlmConfig(GPT_35_TURBO.withParam(TEMPERATURE, 0.0));
3131

3232
/**
3333
* Chat request to OpenAI through the Orchestration service with a simple prompt.

0 commit comments

Comments
 (0)