1010import org .apache .logging .log4j .LogManager ;
1111import org .apache .logging .log4j .Logger ;
1212import org .elasticsearch .TransportVersion ;
13+ import org .elasticsearch .common .Strings ;
1314import org .elasticsearch .common .ValidationException ;
1415import org .elasticsearch .common .io .stream .StreamInput ;
1516import org .elasticsearch .common .io .stream .StreamOutput ;
3536import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
3637import static org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
3738import static org .elasticsearch .xpack .inference .services .ServiceFields .URL ;
38- import static org .elasticsearch .xpack .inference .services .ServiceUtils .convertToUri ;
3939import static org .elasticsearch .xpack .inference .services .ServiceUtils .createOptionalUri ;
40+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
4041import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalString ;
42+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalUri ;
4143import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
42- import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeAsType ;
4344
45+ /**
46+ * Settings for the Cohere service.
47+ * This class encapsulates the configuration settings required to use Cohere models.
48+ */
4449public class CohereServiceSettings extends FilteredXContentObject implements ServiceSettings , CohereRateLimitServiceSettings {
4550
4651 public static final String NAME = "cohere_service_settings" ;
@@ -50,6 +55,9 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
5055
5156 private static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = TransportVersion .fromName ("ml_inference_cohere_api_version" );
5257
58+ /**
59+ * The API versions supported by the Cohere service.
60+ */
5361 public enum CohereApiVersion {
5462 V1 ,
5563 V2 ;
@@ -64,43 +72,46 @@ public static CohereApiVersion fromString(String name) {
6472 // 10K requests a minute
6573 public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings (10_000 );
6674
75+ /**
76+ * Creates {@link CohereServiceSettings} from a map
77+ * @param map the map to parse
78+ * @param context the context in which the parsing is done
79+ * @return the created {@link CohereServiceSettings}
80+ * @throws ValidationException If there are validation errors in the provided settings.
81+ */
6782 public static CohereServiceSettings fromMap (Map <String , Object > map , ConfigurationParseContext context ) {
68- ValidationException validationException = new ValidationException ();
69-
70- String url = extractOptionalString (map , URL , ModelConfigurations .SERVICE_SETTINGS , validationException );
83+ var validationException = new ValidationException ();
7184
72- SimilarityMeasure similarity = extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
73- Integer dims = removeAsType (map , DIMENSIONS , Integer .class );
74- Integer maxInputTokens = removeAsType (map , MAX_INPUT_TOKENS , Integer .class );
75- URI uri = convertToUri (url , URL , ModelConfigurations .SERVICE_SETTINGS , validationException );
76- String oldModelId = extractOptionalString (map , OLD_MODEL_ID_FIELD , ModelConfigurations .SERVICE_SETTINGS , validationException );
77- RateLimitSettings rateLimitSettings = RateLimitSettings .of (
85+ var uri = extractOptionalUri (map , URL , validationException );
86+ var similarity = extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
87+ var dimensions = extractOptionalPositiveInteger (map , DIMENSIONS , ModelConfigurations .SERVICE_SETTINGS , validationException );
88+ var maxInputTokens = extractOptionalPositiveInteger (
7889 map ,
79- DEFAULT_RATE_LIMIT_SETTINGS ,
80- validationException ,
81- CohereService .NAME ,
82- context
90+ MAX_INPUT_TOKENS ,
91+ ModelConfigurations .SERVICE_SETTINGS ,
92+ validationException
8393 );
84-
85- String modelId = extractOptionalString (map , ServiceFields .MODEL_ID , ModelConfigurations .SERVICE_SETTINGS , validationException );
86-
87- if (context == ConfigurationParseContext .REQUEST && oldModelId != null ) {
88- logger .info ("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead." );
89- }
90-
91- var resolvedModelId = modelId (oldModelId , modelId );
94+ var modelId = extractModelId (map , validationException , context );
9295 var apiVersion = apiVersionFromMap (map , context , validationException );
93- if (apiVersion == CohereApiVersion .V2 ) {
94- if (resolvedModelId == null ) {
95- validationException .addValidationError (MODEL_REQUIRED_FOR_V2_API );
96- }
96+ if (apiVersion == CohereApiVersion .V2 && modelId == null ) {
97+ validationException .addValidationError (MODEL_REQUIRED_FOR_V2_API );
9798 }
9899
100+ var rateLimitSettings = RateLimitSettings .of (map , DEFAULT_RATE_LIMIT_SETTINGS , validationException , CohereService .NAME , context );
101+
99102 validationException .throwIfValidationErrorsExist ();
100103
101- return new CohereServiceSettings (uri , similarity , dims , maxInputTokens , resolvedModelId , rateLimitSettings , apiVersion );
104+ return new CohereServiceSettings (uri , similarity , dimensions , maxInputTokens , modelId , rateLimitSettings , apiVersion );
102105 }
103106
107+ /**
108+ * Extracts the Cohere API version from the provided map based on the given context.
109+ *
110+ * @param map the map containing the settings
111+ * @param context the context for parsing configuration settings
112+ * @param validationException the validation exception to collect errors
113+ * @return the extracted Cohere API version
114+ */
104115 public static CohereApiVersion apiVersionFromMap (
105116 Map <String , Object > map ,
106117 ConfigurationParseContext context ,
@@ -127,8 +138,31 @@ public static CohereApiVersion apiVersionFromMap(
127138 };
128139 }
129140
130- private static String modelId (@ Nullable String model , @ Nullable String modelId ) {
131- return modelId != null ? modelId : model ;
141+ private static String extractModelId (
142+ Map <String , Object > serviceSettings ,
143+ ValidationException validationException ,
144+ ConfigurationParseContext context
145+ ) {
146+ var extractedOldModelId = extractOptionalString (
147+ serviceSettings ,
148+ OLD_MODEL_ID_FIELD ,
149+ ModelConfigurations .SERVICE_SETTINGS ,
150+ validationException
151+ );
152+ if (context == ConfigurationParseContext .REQUEST && extractedOldModelId != null ) {
153+ logger .info ("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead." );
154+ }
155+ var extractedModelId = extractOptionalString (
156+ serviceSettings ,
157+ ServiceFields .MODEL_ID ,
158+ ModelConfigurations .SERVICE_SETTINGS ,
159+ validationException
160+ );
161+ return selectModelId (extractedOldModelId , extractedModelId );
162+ }
163+
164+ private static String selectModelId (@ Nullable String oldModelId , @ Nullable String newModelId ) {
165+ return newModelId != null ? newModelId : oldModelId ;
132166 }
133167
134168 private final URI uri ;
@@ -139,6 +173,17 @@ private static String modelId(@Nullable String model, @Nullable String modelId)
139173 private final RateLimitSettings rateLimitSettings ;
140174 private final CohereApiVersion apiVersion ;
141175
176+ /**
177+ * Constructs a new {@link CohereServiceSettings} instance.
178+ *
179+ * @param uri the URI of the Cohere service
180+ * @param similarity the similarity measure to use
181+ * @param dimensions the number of dimensions for embeddings
182+ * @param maxInputTokens the maximum number of input tokens
183+ * @param modelId the model identifier
184+ * @param rateLimitSettings the rate limit settings
185+ * @param apiVersion the Cohere API version
186+ */
142187 public CohereServiceSettings (
143188 @ Nullable URI uri ,
144189 @ Nullable SimilarityMeasure similarity ,
@@ -169,6 +214,12 @@ public CohereServiceSettings(
169214 this (createOptionalUri (url ), similarity , dimensions , maxInputTokens , modelId , rateLimitSettings , apiVersion );
170215 }
171216
217+ /**
218+ * Constructs a new {@link CohereServiceSettings} instance from a {@link StreamInput}.
219+ *
220+ * @param in the stream input to read from
221+ * @throws IOException if an I/O error occurs
222+ */
172223 public CohereServiceSettings (StreamInput in ) throws IOException {
173224 uri = createOptionalUri (in .readOptionalString ());
174225 similarity = in .readOptionalEnum (SimilarityMeasure .class );
@@ -183,7 +234,7 @@ public CohereServiceSettings(StreamInput in) throws IOException {
183234 }
184235 }
185236
186- // should only be used for testing, public because it's accessed outside of the package
237+ // should only be used for testing, public because it's accessed outside the package
187238 public CohereServiceSettings (CohereApiVersion apiVersion ) {
188239 this ((URI ) null , null , null , null , null , null , apiVersion );
189240 }
@@ -221,6 +272,34 @@ public String modelId() {
221272 return modelId ;
222273 }
223274
275+ public CohereServiceSettings updateCommonServiceSettings (Map <String , Object > serviceSettings , ValidationException validationException ) {
276+
277+ var extractedMaxInputTokens = extractOptionalPositiveInteger (
278+ serviceSettings ,
279+ MAX_INPUT_TOKENS ,
280+ ModelConfigurations .SERVICE_SETTINGS ,
281+ validationException
282+ );
283+
284+ var extractedRateLimitSettings = RateLimitSettings .of (
285+ serviceSettings ,
286+ this .rateLimitSettings ,
287+ validationException ,
288+ CohereService .NAME ,
289+ ConfigurationParseContext .REQUEST
290+ );
291+
292+ return new CohereServiceSettings (
293+ this .uri ,
294+ this .similarity ,
295+ this .dimensions ,
296+ extractedMaxInputTokens != null ? extractedMaxInputTokens : this .maxInputTokens ,
297+ this .modelId ,
298+ extractedRateLimitSettings ,
299+ this .apiVersion
300+ );
301+ }
302+
224303 @ Override
225304 public String getWriteableName () {
226305 return NAME ;
@@ -281,6 +360,11 @@ public void writeTo(StreamOutput out) throws IOException {
281360 }
282361 }
283362
363+ @ Override
364+ public String toString () {
365+ return Strings .toString (this );
366+ }
367+
284368 @ Override
285369 public boolean equals (Object o ) {
286370 if (this == o ) return true ;
0 commit comments