Skip to content

Commit f56e6c5

Browse files
committed
Adjust solution
1 parent 2c7fc4b commit f56e6c5

File tree

4 files changed

+109
-129
lines changed

4 files changed

+109
-129
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
package com.sap.ai.sdk.orchestration;
22

33
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
4+
import java.util.LinkedHashMap;
45
import java.util.Map;
56
import javax.annotation.Nonnull;
6-
import lombok.AccessLevel;
7+
import javax.annotation.Nullable;
78
import lombok.AllArgsConstructor;
8-
import lombok.EqualsAndHashCode;
9-
import lombok.Getter;
9+
import lombok.RequiredArgsConstructor;
10+
import lombok.Value;
1011
import lombok.With;
11-
import lombok.experimental.Tolerate;
1212

1313
/** Large language models available in Orchestration. */
14+
@Value
1415
@With
1516
@AllArgsConstructor
16-
@Getter(AccessLevel.PACKAGE)
17-
@EqualsAndHashCode
1817
public class OrchestrationAiModel {
1918
/** The name of the model */
20-
private final String name;
19+
String name;
2120

2221
/**
2322
* Optional parameters on this model.
@@ -30,10 +29,10 @@ public class OrchestrationAiModel {
3029
* "presence_penalty", 0)
3130
* }</pre>
3231
*/
33-
private final Map<String, Object> params;
32+
Map<String, Object> params;
3433

3534
/** The version of the model, defaults to "latest". */
36-
private final String version;
35+
String version;
3736

3837
/** IBM Granite 13B chat completions model */
3938
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
@@ -80,28 +79,23 @@ public class OrchestrationAiModel {
8079
new OrchestrationAiModel("amazon--titan-text-express");
8180

8281
/** Azure OpenAI GPT-3.5 Turbo chat completions model */
83-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO =
84-
new Parameterized<>("gpt-35-turbo");
82+
public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");
8583

8684
/** Azure OpenAI GPT-3.5 Turbo (16k context) chat completions model */
87-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO_16K =
88-
new Parameterized<>("gpt-35-turbo-16k");
85+
public static final OrchestrationAiModel GPT_35_TURBO_16K =
86+
new OrchestrationAiModel("gpt-35-turbo-16k");
8987

9088
/** Azure OpenAI GPT-4 chat completions model */
91-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4 =
92-
new Parameterized<>("gpt-4");
89+
public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");
9390

9491
/** Azure OpenAI GPT-4-32k chat completions model */
95-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4_32K =
96-
new Parameterized<>("gpt-4-32k");
92+
public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");
9793

9894
/** Azure OpenAI GPT-4o chat completions model */
99-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O =
100-
new Parameterized<>("gpt-4o");
95+
public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");
10196

10297
/** Azure OpenAI GPT-4o-mini chat completions model */
103-
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O_MINI =
104-
new Parameterized<>("gpt-4o-mini");
98+
public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");
10599

106100
/** Google Cloud Platform Gemini 1.0 Pro model */
107101
public static final OrchestrationAiModel GEMINI_1_0_PRO =
@@ -125,26 +119,56 @@ LLMModuleConfig createConfig() {
125119
}
126120

127121
/**
128-
* Subclass to allow for parameterized models.
122+
* Additional parameter on this model.
129123
*
130-
* @param <T> The type of parameters for this model.
124+
* @param key the parameter key.
125+
* @param value the parameter value, nullable.
126+
* @return A new model with the additional parameter.
131127
*/
132-
public static final class Parameterized<T extends OrchestrationAiModelParameters>
133-
extends OrchestrationAiModel {
134-
private Parameterized(@Nonnull final String name) {
135-
super(name);
136-
}
137-
138-
/**
139-
* Set the typed parameters for this model.
140-
*
141-
* @param params The parameters for this model.
142-
* @return The model with the parameters set.
143-
*/
144-
@Tolerate
145-
@Nonnull
146-
public OrchestrationAiModel withParams(@Nonnull final T params) {
147-
return super.withParams(params.getParams());
148-
}
128+
@Nonnull
129+
public OrchestrationAiModel withParam(@Nonnull final String key, @Nullable final Object value) {
130+
final var params = new LinkedHashMap<>(getParams());
131+
params.put(key, value);
132+
return withParams(params);
133+
}
134+
135+
/**
136+
* Additional parameter on this model.
137+
*
138+
* @param param the parameter key.
139+
* @param value the parameter value, nullable.
140+
* @return A new model with the additional parameter.
141+
*/
142+
@Nonnull
143+
public OrchestrationAiModel withParam(
144+
@Nonnull final Parameter param, @Nullable final Object value) {
145+
return withParam(param.value, value);
146+
}
147+
148+
/** Parameter key for a model. */
149+
@RequiredArgsConstructor
150+
public enum Parameter {
151+
/** The maximum number of tokens to generate. */
152+
MAX_TOKENS("max_tokens"),
153+
/** The sampling temperature. */
154+
TEMPERATURE("temperature"),
155+
/** The frequency penalty. */
156+
FREQUENCY_PENALTY("frequency_penalty"),
157+
/** The presence penalty. */
158+
PRESENCE_PENALTY("presence_penalty"),
159+
/** The maximum number of tokens for completion. */
160+
MAX_COMPLETION_TOKENS("max_completion_tokens"),
161+
/** The probability mass to be considered. */
162+
TOP_P("top_p"),
163+
/** The toggle to enable partial message delta. */
164+
STREAM("stream"),
165+
/** The options for streaming response. */
166+
STREAM_OPTIONS("stream_options"),
167+
/** The tokens where the API will stop generating further tokens. */
168+
STOP("stop"),
169+
/** The number of chat completion choices to generate for each input message. */
170+
N("n");
171+
172+
private final String value;
149173
}
150174
}

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModelParameters.java

Lines changed: 0 additions & 86 deletions
This file was deleted.

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
44
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O;
5+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.MAX_TOKENS;
56
import static org.assertj.core.api.Assertions.assertThat;
67
import static org.assertj.core.api.Assertions.assertThatThrownBy;
78

@@ -65,6 +66,44 @@ void testDpiMaskingConfig() {
6566
.hasSize(1);
6667
}
6768

69+
@Test
70+
void testParams() {
71+
// test withParams(Map<String, Object>)
72+
{
73+
var params = Map.<String, Object>of("foo", "bar", "fizz", "buzz");
74+
75+
var modelA = GPT_4O.withParams(params);
76+
var modelB = modelA.withParams(params);
77+
assertThat(modelA).isEqualTo(modelB);
78+
79+
var modelC = modelA.withParams(Map.of("foo", "bar"));
80+
assertThat(modelA).isNotEqualTo(modelC);
81+
82+
var modelD = modelA.withParams(Map.of("foo", "bazz"));
83+
assertThat(modelA).isNotEqualTo(modelD);
84+
}
85+
86+
// test withParam(String, Object)
87+
{
88+
var modelA = GPT_4O.withParam("foo", "bar");
89+
var modelB = modelA.withParam("foo", "bar");
90+
assertThat(modelA).isEqualTo(modelB);
91+
92+
var modelC = modelA.withParam("foo", "bazz");
93+
assertThat(modelA).isNotEqualTo(modelC);
94+
}
95+
96+
// test withParam(Parameter, Object)
97+
{
98+
var modelA = GPT_4O.withParam(MAX_TOKENS, 10);
99+
var modelB = modelA.withParam(MAX_TOKENS, 10);
100+
assertThat(modelA).isEqualTo(modelB);
101+
102+
var modelC = modelA.withParam(MAX_TOKENS, 20);
103+
assertThat(modelA).isNotEqualTo(modelC);
104+
}
105+
}
106+
68107
@Test
69108
void testLLMConfig() {
70109
Map<String, Object> params = Map.of("foo", "bar");

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE;
2020
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
2121
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
22-
import static com.sap.ai.sdk.orchestration.OrchestrationAiModelParameters.GPT.params;
22+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.*;
2323
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
2424
import static org.assertj.core.api.Assertions.assertThat;
2525
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -53,8 +53,11 @@
5353
@WireMockTest
5454
class OrchestrationUnitTest {
5555
static final OrchestrationAiModel CUSTOM_GPT_35 =
56-
GPT_35_TURBO_16K.withParams(
57-
params().maxTokens(50).temperature(0.1).frequencyPenalty(0).presencePenalty(0));
56+
GPT_35_TURBO_16K
57+
.withParam(MAX_TOKENS, 50)
58+
.withParam(TEMPERATURE, 0.1)
59+
.withParam(FREQUENCY_PENALTY, 0)
60+
.withParam(PRESENCE_PENALTY, 0);
5861

5962
private final Function<String, InputStream> fileLoader =
6063
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));

0 commit comments

Comments
 (0)