|
2 | 2 |
|
3 | 3 | import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; |
4 | 4 | import java.util.LinkedHashMap; |
| 5 | +import java.util.List; |
5 | 6 | import java.util.Map; |
6 | 7 | import javax.annotation.Nonnull; |
7 | 8 | import javax.annotation.Nullable; |
8 | 9 | import lombok.AllArgsConstructor; |
9 | | -import lombok.RequiredArgsConstructor; |
10 | 10 | import lombok.Value; |
11 | 11 | import lombok.With; |
12 | 12 |
|
@@ -137,38 +137,58 @@ public OrchestrationAiModel withParam(@Nonnull final String key, @Nullable final |
137 | 137 | * |
138 | 138 | * @param param the parameter key. |
139 | 139 | * @param value the parameter value, nullable. |
| 140 | + * @param <ValueT> the parameter value type. |
140 | 141 | * @return A new model with the additional parameter. |
141 | 142 | */ |
142 | 143 | @Nonnull |
143 | | - public OrchestrationAiModel withParam( |
144 | | - @Nonnull final Parameter param, @Nullable final Object value) { |
145 | | - return withParam(param.value, value); |
| 144 | + public <ValueT> OrchestrationAiModel withParam( |
| 145 | + @Nonnull final Parameter<ValueT> param, @Nullable final ValueT value) { |
| 146 | + return withParam(param.getName(), value); |
146 | 147 | } |
147 | 148 |
|
148 | | - /** Parameter key for a model. */ |
149 | | - @RequiredArgsConstructor |
150 | | - public enum Parameter { |
| 149 | + /** |
| 150 | + * Parameter key for a model. |
| 151 | + * |
| 152 | + * @param <ValueT> the parameter value type. |
| 153 | + */ |
| 154 | + @FunctionalInterface |
| 155 | + public interface Parameter<ValueT> { |
151 | 156 | /** The maximum number of tokens to generate. */ |
152 | | - MAX_TOKENS("max_tokens"), |
| 157 | + Parameter<Integer> MAX_TOKENS = () -> "max_tokens"; |
| 158 | + |
153 | 159 | /** The sampling temperature. */ |
154 | | - TEMPERATURE("temperature"), |
| 160 | + Parameter<Number> TEMPERATURE = () -> "temperature"; |
| 161 | + |
155 | 162 | /** The frequency penalty. */ |
156 | | - FREQUENCY_PENALTY("frequency_penalty"), |
| 163 | + Parameter<Number> FREQUENCY_PENALTY = () -> "frequency_penalty"; |
| 164 | + |
157 | 165 | /** The presence penalty. */ |
158 | | - PRESENCE_PENALTY("presence_penalty"), |
| 166 | + Parameter<Number> PRESENCE_PENALTY = () -> "presence_penalty"; |
| 167 | + |
159 | 168 | /** The maximum number of tokens for completion */ |
160 | | - MAX_COMPLETION_TOKENS("max_completion_tokens"), |
| 169 | + Parameter<Integer> MAX_COMPLETION_TOKENS = () -> "max_completion_tokens"; |
| 170 | + |
161 | 171 | /** The probability mass to be considered . */ |
162 | | - TOP_P("top_p"), |
| 172 | + Parameter<Number> TOP_P = () -> "top_p"; |
| 173 | + |
163 | 174 | /** The toggle to enable partial message delta. */ |
164 | | - STREAM("stream"), |
165 | | - /** The options for streaming response. */ |
166 | | - STREAM_OPTIONS("stream_options"), |
| 175 | + Parameter<Boolean> STREAM = () -> "stream"; |
| 176 | + |
| 177 | + /** The options for streaming response. Only used in combination with STREAM = true. */ |
| 178 | + Parameter<Map<String, Object>> STREAM_OPTIONS = () -> "stream_options"; |
| 179 | + |
167 | 180 | /** The tokens where the API will stop generating further tokens. */ |
168 | | - STOP("stop"), |
169 | | - /** The number of chat completion choices to generate for each input message. */ |
170 | | - N("n"); |
| 181 | + Parameter<List<String>> STOP = () -> "stop"; |
171 | 182 |
|
172 | | - private final String value; |
| 183 | + /** The number of chat completion choices to generate for each input message. */ |
| 184 | + Parameter<Integer> N = () -> "n"; |
| 185 | + |
| 186 | + /** |
| 187 | + * The name of the parameter. |
| 188 | + * |
| 189 | + * @return the name of the parameter. |
| 190 | + */ |
| 191 | + @Nonnull |
| 192 | + String getName(); |
173 | 193 | } |
174 | 194 | } |
0 commit comments