Skip to content

Commit f360c60

Browse files
committed
Adjust solution to (value) type safety.
1 parent f56e6c5 commit f360c60

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationAiModel.java

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
44
import java.util.LinkedHashMap;
5+
import java.util.List;
56
import java.util.Map;
67
import javax.annotation.Nonnull;
78
import javax.annotation.Nullable;
89
import lombok.AllArgsConstructor;
9-
import lombok.RequiredArgsConstructor;
1010
import lombok.Value;
1111
import lombok.With;
1212

@@ -137,38 +137,58 @@ public OrchestrationAiModel withParam(@Nonnull final String key, @Nullable final
137137
*
138138
* @param param the parameter key.
139139
* @param value the parameter value, nullable.
140+
* @param <ValueT> the parameter value type.
140141
* @return A new model with the additional parameter.
141142
*/
142143
@Nonnull
143-
public OrchestrationAiModel withParam(
144-
@Nonnull final Parameter param, @Nullable final Object value) {
145-
return withParam(param.value, value);
144+
public <ValueT> OrchestrationAiModel withParam(
145+
@Nonnull final Parameter<ValueT> param, @Nullable final ValueT value) {
146+
return withParam(param.getName(), value);
146147
}
147148

148-
/** Parameter key for a model. */
149-
@RequiredArgsConstructor
150-
public enum Parameter {
149+
/**
150+
* Parameter key for a model.
151+
*
152+
* @param <ValueT> the parameter value type.
153+
*/
154+
@FunctionalInterface
155+
public interface Parameter<ValueT> {
151156
/** The maximum number of tokens to generate. */
152-
MAX_TOKENS("max_tokens"),
157+
Parameter<Integer> MAX_TOKENS = () -> "max_tokens";
158+
153159
/** The sampling temperature. */
154-
TEMPERATURE("temperature"),
160+
Parameter<Number> TEMPERATURE = () -> "temperature";
161+
155162
/** The frequency penalty. */
156-
FREQUENCY_PENALTY("frequency_penalty"),
163+
Parameter<Number> FREQUENCY_PENALTY = () -> "frequency_penalty";
164+
157165
/** The presence penalty. */
158-
PRESENCE_PENALTY("presence_penalty"),
166+
Parameter<Number> PRESENCE_PENALTY = () -> "presence_penalty";
167+
159168
/** The maximum number of tokens for completion */
160-
MAX_COMPLETION_TOKENS("max_completion_tokens"),
169+
Parameter<Integer> MAX_COMPLETION_TOKENS = () -> "max_completion_tokens";
170+
161171
/** The probability mass to be considered . */
162-
TOP_P("top_p"),
172+
Parameter<Number> TOP_P = () -> "top_p";
173+
163174
/** The toggle to enable partial message delta. */
164-
STREAM("stream"),
165-
/** The options for streaming response. */
166-
STREAM_OPTIONS("stream_options"),
175+
Parameter<Boolean> STREAM = () -> "stream";
176+
177+
/** The options for streaming response. Only used in combination with STREAM = true. */
178+
Parameter<Map<String, Object>> STREAM_OPTIONS = () -> "stream_options";
179+
167180
/** The tokens where the API will stop generating further tokens. */
168-
STOP("stop"),
169-
/** The number of chat completion choices to generate for each input message. */
170-
N("n");
181+
Parameter<List<String>> STOP = () -> "stop";
171182

172-
private final String value;
183+
/** The number of chat completion choices to generate for each input message. */
184+
Parameter<Integer> N = () -> "n";
185+
186+
/**
187+
* The name of the parameter.
188+
*
189+
* @return the name of the parameter.
190+
*/
191+
@Nonnull
192+
String getName();
173193
}
174194
}

0 commit comments

Comments
 (0)