Skip to content

Commit bce9be0

Browse files
committed
Initial
1 parent bf1f491 commit bce9be0

File tree

3 files changed

+134
-17
lines changed

3 files changed

+134
-17
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,21 @@
33
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
44
import java.util.Map;
55
import javax.annotation.Nonnull;
6+
import lombok.AccessLevel;
67
import lombok.AllArgsConstructor;
7-
import lombok.Value;
8+
import lombok.EqualsAndHashCode;
9+
import lombok.Getter;
810
import lombok.With;
11+
import lombok.experimental.Tolerate;
912

1013
/** Large language models available in Orchestration. */
11-
@Value
1214
@With
1315
@AllArgsConstructor
16+
@Getter(AccessLevel.PACKAGE)
17+
@EqualsAndHashCode
1418
public class OrchestrationAiModel {
1519
/** The name of the model */
16-
String name;
20+
private final String name;
1721

1822
/**
1923
* Optional parameters on this model.
@@ -26,10 +30,10 @@ public class OrchestrationAiModel {
2630
* "presence_penalty", 0)
2731
* }</pre>
2832
*/
29-
Map<String, Object> params;
33+
private final Map<String, Object> params;
3034

3135
/** The version of the model, defaults to "latest". */
32-
String version;
36+
private final String version;
3337

3438
/** IBM Granite 13B chat completions model */
3539
public static final OrchestrationAiModel IBM_GRANITE_13B_CHAT =
@@ -76,23 +80,28 @@ public class OrchestrationAiModel {
7680
new OrchestrationAiModel("amazon--titan-text-express");
7781

7882
/** Azure OpenAI GPT-3.5 Turbo chat completions model */
79-
public static final OrchestrationAiModel GPT_35_TURBO = new OrchestrationAiModel("gpt-35-turbo");
83+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO =
84+
new Parameterized<>("gpt-35-turbo");
8085

8186
/** Azure OpenAI GPT-3.5 Turbo chat completions model */
82-
public static final OrchestrationAiModel GPT_35_TURBO_16K =
83-
new OrchestrationAiModel("gpt-35-turbo-16k");
87+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_35_TURBO_16K =
88+
new Parameterized<>("gpt-35-turbo-16k");
8489

8590
/** Azure OpenAI GPT-4 chat completions model */
86-
public static final OrchestrationAiModel GPT_4 = new OrchestrationAiModel("gpt-4");
91+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4 =
92+
new Parameterized<>("gpt-4");
8793

8894
/** Azure OpenAI GPT-4-32k chat completions model */
89-
public static final OrchestrationAiModel GPT_4_32K = new OrchestrationAiModel("gpt-4-32k");
95+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4_32K =
96+
new Parameterized<>("gpt-4-32k");
9097

9198
/** Azure OpenAI GPT-4o chat completions model */
92-
public static final OrchestrationAiModel GPT_4O = new OrchestrationAiModel("gpt-4o");
99+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O =
100+
new Parameterized<>("gpt-4o");
93101

94102
/** Azure OpenAI GPT-4o-mini chat completions model */
95-
public static final OrchestrationAiModel GPT_4O_MINI = new OrchestrationAiModel("gpt-4o-mini");
103+
public static final Parameterized<OrchestrationAiModelParameters.GPT> GPT_4O_MINI =
104+
new Parameterized<>("gpt-4o-mini");
96105

97106
/** Google Cloud Platform Gemini 1.0 Pro model */
98107
public static final OrchestrationAiModel GEMINI_1_0_PRO =
@@ -114,4 +123,28 @@ public class OrchestrationAiModel {
114123
  /**
   * Builds the low-level LLM module configuration for an orchestration request from this model's
   * name, optional parameters, and version.
   *
   * @return a new {@link LLMModuleConfig} carrying {@code name}, {@code params}, and {@code
   *     version}.
   */
  LLMModuleConfig createConfig() {
    return new LLMModuleConfig().modelName(name).modelParams(params).modelVersion(version);
  }
126+
127+
/**
128+
* Subclass to allow for parameterized models.
129+
*
130+
* @param <T> The type of parameters for this model.
131+
*/
132+
  /**
   * Subclass to allow for parameterized models.
   *
   * @param <T> The type of parameters for this model.
   */
  public static final class Parameterized<T extends OrchestrationAiModelParameters>
      extends OrchestrationAiModel {
    // Instances are created only by the enclosing class's static model constants.
    private Parameterized(@Nonnull final String name) {
      // NOTE(review): relies on a single-argument constructor of OrchestrationAiModel that is not
      // visible in this diff (only @AllArgsConstructor is shown) — confirm it exists.
      super(name);
    }

    /**
     * Set the typed parameters for this model.
     *
     * <p>Overloads the Lombok {@code @With}-generated {@code withParams(Map)} on the parent;
     * {@code @Tolerate} lets Lombok accept the hand-written overload alongside its own.
     *
     * @param params The parameters for this model.
     * @return The model with the parameters set.
     */
    @Tolerate
    @Nonnull
    public OrchestrationAiModel withParams(@Nonnull final T params) {
      return super.withParams(params.getParams());
    }
  }
117150
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package com.sap.ai.sdk.orchestration;
2+
3+
import java.util.Map;
4+
import javax.annotation.Nonnull;
5+
6+
/** Helper interface to define typed parameters. */
7+
@FunctionalInterface
8+
public interface OrchestrationAiModelParameters {
9+
/**
10+
* Get the parameters.
11+
*
12+
* @return the parameters.
13+
*/
14+
@Nonnull
15+
Map<String, Object> getParams();
16+
17+
/** GPT model parameters. */
18+
interface GPT extends OrchestrationAiModelParameters {
19+
/**
20+
* Create a new builder.
21+
*
22+
* @return the builder.
23+
*/
24+
@Nonnull
25+
static GPT.Builder0 params() {
26+
return maxTokens ->
27+
temperature ->
28+
frequencyPenalty ->
29+
presencePenalty ->
30+
() ->
31+
Map.of(
32+
"max_tokens", maxTokens,
33+
"temperature", temperature,
34+
"frequency_penalty", frequencyPenalty,
35+
"presence_penalty", presencePenalty);
36+
}
37+
38+
/** Builder for GPT model parameters. */
39+
interface Builder0 {
40+
/**
41+
* Set the max tokens.
42+
*
43+
* @param maxTokens the max tokens.
44+
* @return the next builder.
45+
*/
46+
@Nonnull
47+
GPT.Builder1 maxTokens(@Nonnull final Number maxTokens);
48+
}
49+
50+
/** Builder for GPT model parameters. */
51+
interface Builder1 {
52+
/**
53+
* Set the temperature.
54+
*
55+
* @param temperature the temperature.
56+
* @return the next builder.
57+
*/
58+
@Nonnull
59+
GPT.Builder2 temperature(@Nonnull final Number temperature);
60+
}
61+
62+
/** Builder for GPT model parameters. */
63+
interface Builder2 {
64+
/**
65+
* Set the frequency penalty.
66+
*
67+
* @param frequencyPenalty the frequency penalty.
68+
* @return the next builder.
69+
*/
70+
@Nonnull
71+
GPT.Builder3 frequencyPenalty(@Nonnull final Number frequencyPenalty);
72+
}
73+
74+
/** Builder for GPT model parameters. */
75+
interface Builder3 {
76+
/**
77+
* Set the presence penalty.
78+
*
79+
* @param presencePenalty the presence penalty.
80+
* @return the final typed parameter object.
81+
*/
82+
@Nonnull
83+
GPT presencePenalty(@Nonnull final Number presencePenalty);
84+
}
85+
}
86+
}

orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE;
2020
import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM;
2121
import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_35_TURBO_16K;
22+
import static com.sap.ai.sdk.orchestration.OrchestrationAiModelParameters.GPT.params;
2223
import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST;
2324
import static org.assertj.core.api.Assertions.assertThat;
2425
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -53,11 +54,8 @@
5354
class OrchestrationUnitTest {
5455
static final OrchestrationAiModel CUSTOM_GPT_35 =
5556
GPT_35_TURBO_16K.withParams(
56-
Map.of(
57-
"max_tokens", 50,
58-
"temperature", 0.1,
59-
"frequency_penalty", 0,
60-
"presence_penalty", 0));
57+
params().maxTokens(50).temperature(0.1).frequencyPenalty(0).presencePenalty(0));
58+
6159
private final Function<String, InputStream> fileLoader =
6260
filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename));
6361

0 commit comments

Comments
 (0)