Skip to content

Commit d77a9f3

Browse files
[Inference API] Implement updateServiceSettings() for Cohere service (elastic#145360)
* [Inference API] Implement updateServiceSettings() for Cohere service
* [Inference API] Implement updateServiceSettings() for Cohere service
  Closes elastic/search-team#13622
  Made-with: Cursor
* Fix checkstyle LineLength in Cohere settings tests
  Wrap Javadoc lines to stay within the 140-column limit enforced by checkstyleTest for the inference plugin.
  Made-with: Cursor
* [Inference API] Refactor Cohere service settings tests to streamline API version handling and update mutable fields
* [Inference API] Update Cohere service settings tests to use random configuration parse context and improve variable naming
1 parent f6e01b1 commit d77a9f3

File tree

10 files changed

+987
-512
lines changed

10 files changed

+987
-512
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ public CohereModel buildModelFromConfigAndSecrets(ModelConfigurations config, Mo
176176
config.getInferenceEntityId(),
177177
config.getTaskType(),
178178
config.getService(),
179-
ConfigurationParseContext.PERSISTENT
179+
ConfigurationParseContext.REQUEST
180180
).createFromModelConfigurationsAndSecrets(config, secrets);
181181
}
182182

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java

Lines changed: 115 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.TransportVersion;
13+
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.common.io.stream.StreamInput;
1516
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -35,12 +36,16 @@
3536
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
3637
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
3738
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
38-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
3939
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
40+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
4041
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
42+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalUri;
4143
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
42-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
4344

45+
/**
46+
* Settings for the Cohere service.
47+
* This class encapsulates the configuration settings required to use Cohere models.
48+
*/
4449
public class CohereServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings {
4550

4651
public static final String NAME = "cohere_service_settings";
@@ -50,6 +55,9 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
5055

5156
private static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = TransportVersion.fromName("ml_inference_cohere_api_version");
5257

58+
/**
59+
* The API versions supported by the Cohere service.
60+
*/
5361
public enum CohereApiVersion {
5462
V1,
5563
V2;
@@ -64,43 +72,46 @@ public static CohereApiVersion fromString(String name) {
6472
// 10K requests a minute
6573
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
6674

75+
/**
76+
* Creates {@link CohereServiceSettings} from a map
77+
* @param map the map to parse
78+
* @param context the context in which the parsing is done
79+
* @return the created {@link CohereServiceSettings}
80+
* @throws ValidationException If there are validation errors in the provided settings.
81+
*/
6782
public static CohereServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
68-
ValidationException validationException = new ValidationException();
69-
70-
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
83+
var validationException = new ValidationException();
7184

72-
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
73-
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
74-
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
75-
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
76-
String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
77-
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
85+
var uri = extractOptionalUri(map, URL, validationException);
86+
var similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
87+
var dimensions = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
88+
var maxInputTokens = extractOptionalPositiveInteger(
7889
map,
79-
DEFAULT_RATE_LIMIT_SETTINGS,
80-
validationException,
81-
CohereService.NAME,
82-
context
90+
MAX_INPUT_TOKENS,
91+
ModelConfigurations.SERVICE_SETTINGS,
92+
validationException
8393
);
84-
85-
String modelId = extractOptionalString(map, ServiceFields.MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
86-
87-
if (context == ConfigurationParseContext.REQUEST && oldModelId != null) {
88-
logger.info("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead.");
89-
}
90-
91-
var resolvedModelId = modelId(oldModelId, modelId);
94+
var modelId = extractModelId(map, validationException, context);
9295
var apiVersion = apiVersionFromMap(map, context, validationException);
93-
if (apiVersion == CohereApiVersion.V2) {
94-
if (resolvedModelId == null) {
95-
validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API);
96-
}
96+
if (apiVersion == CohereApiVersion.V2 && modelId == null) {
97+
validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API);
9798
}
9899

100+
var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, CohereService.NAME, context);
101+
99102
validationException.throwIfValidationErrorsExist();
100103

101-
return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, resolvedModelId, rateLimitSettings, apiVersion);
104+
return new CohereServiceSettings(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion);
102105
}
103106

107+
/**
108+
* Extracts the Cohere API version from the provided map based on the given context.
109+
*
110+
* @param map the map containing the settings
111+
* @param context the context for parsing configuration settings
112+
* @param validationException the validation exception to collect errors
113+
* @return the extracted Cohere API version
114+
*/
104115
public static CohereApiVersion apiVersionFromMap(
105116
Map<String, Object> map,
106117
ConfigurationParseContext context,
@@ -127,8 +138,31 @@ public static CohereApiVersion apiVersionFromMap(
127138
};
128139
}
129140

130-
private static String modelId(@Nullable String model, @Nullable String modelId) {
131-
return modelId != null ? modelId : model;
141+
private static String extractModelId(
142+
Map<String, Object> serviceSettings,
143+
ValidationException validationException,
144+
ConfigurationParseContext context
145+
) {
146+
var extractedOldModelId = extractOptionalString(
147+
serviceSettings,
148+
OLD_MODEL_ID_FIELD,
149+
ModelConfigurations.SERVICE_SETTINGS,
150+
validationException
151+
);
152+
if (context == ConfigurationParseContext.REQUEST && extractedOldModelId != null) {
153+
logger.info("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead.");
154+
}
155+
var extractedModelId = extractOptionalString(
156+
serviceSettings,
157+
ServiceFields.MODEL_ID,
158+
ModelConfigurations.SERVICE_SETTINGS,
159+
validationException
160+
);
161+
return selectModelId(extractedOldModelId, extractedModelId);
162+
}
163+
164+
private static String selectModelId(@Nullable String oldModelId, @Nullable String newModelId) {
165+
return newModelId != null ? newModelId : oldModelId;
132166
}
133167

134168
private final URI uri;
@@ -139,6 +173,17 @@ private static String modelId(@Nullable String model, @Nullable String modelId)
139173
private final RateLimitSettings rateLimitSettings;
140174
private final CohereApiVersion apiVersion;
141175

176+
/**
177+
* Constructs a new {@link CohereServiceSettings} instance.
178+
*
179+
* @param uri the URI of the Cohere service
180+
* @param similarity the similarity measure to use
181+
* @param dimensions the number of dimensions for embeddings
182+
* @param maxInputTokens the maximum number of input tokens
183+
* @param modelId the model identifier
184+
* @param rateLimitSettings the rate limit settings
185+
* @param apiVersion the Cohere API version
186+
*/
142187
public CohereServiceSettings(
143188
@Nullable URI uri,
144189
@Nullable SimilarityMeasure similarity,
@@ -169,6 +214,12 @@ public CohereServiceSettings(
169214
this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion);
170215
}
171216

217+
/**
218+
* Constructs a new {@link CohereServiceSettings} instance from a {@link StreamInput}.
219+
*
220+
* @param in the stream input to read from
221+
* @throws IOException if an I/O error occurs
222+
*/
172223
public CohereServiceSettings(StreamInput in) throws IOException {
173224
uri = createOptionalUri(in.readOptionalString());
174225
similarity = in.readOptionalEnum(SimilarityMeasure.class);
@@ -183,7 +234,7 @@ public CohereServiceSettings(StreamInput in) throws IOException {
183234
}
184235
}
185236

186-
// should only be used for testing, public because it's accessed outside of the package
237+
// should only be used for testing, public because it's accessed outside the package
187238
public CohereServiceSettings(CohereApiVersion apiVersion) {
188239
this((URI) null, null, null, null, null, null, apiVersion);
189240
}
@@ -221,6 +272,34 @@ public String modelId() {
221272
return modelId;
222273
}
223274

275+
public CohereServiceSettings updateCommonServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
276+
277+
var extractedMaxInputTokens = extractOptionalPositiveInteger(
278+
serviceSettings,
279+
MAX_INPUT_TOKENS,
280+
ModelConfigurations.SERVICE_SETTINGS,
281+
validationException
282+
);
283+
284+
var extractedRateLimitSettings = RateLimitSettings.of(
285+
serviceSettings,
286+
this.rateLimitSettings,
287+
validationException,
288+
CohereService.NAME,
289+
ConfigurationParseContext.REQUEST
290+
);
291+
292+
return new CohereServiceSettings(
293+
this.uri,
294+
this.similarity,
295+
this.dimensions,
296+
extractedMaxInputTokens != null ? extractedMaxInputTokens : this.maxInputTokens,
297+
this.modelId,
298+
extractedRateLimitSettings,
299+
this.apiVersion
300+
);
301+
}
302+
224303
@Override
225304
public String getWriteableName() {
226305
return NAME;
@@ -281,6 +360,11 @@ public void writeTo(StreamOutput out) throws IOException {
281360
}
282361
}
283362

363+
@Override
364+
public String toString() {
365+
return Strings.toString(this);
366+
}
367+
284368
@Override
285369
public boolean equals(Object o) {
286370
if (this == o) return true;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.services.cohere.completion;
99

1010
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.common.Strings;
1112
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamInput;
1314
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -29,13 +30,17 @@
2930

3031
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
3132
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
32-
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
3333
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
3434
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
35+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalUri;
3536
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.API_VERSION;
3637
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.MODEL_REQUIRED_FOR_V2_API;
3738
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.apiVersionFromMap;
3839

40+
/**
41+
* Settings for the Cohere completion service.
42+
* This class encapsulates the configuration settings required to use Cohere models for generating completions.
43+
*/
3944
public class CohereCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings {
4045

4146
public static final String NAME = "cohere_completion_service_settings";
@@ -45,24 +50,23 @@ public class CohereCompletionServiceSettings extends FilteredXContentObject impl
4550
// 10K requests per minute
4651
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
4752

53+
/**
54+
* Creates an instance of {@link CohereCompletionServiceSettings} from a map of settings.
55+
*
56+
* @param map The map containing the settings.
57+
* @param context The context for configuration parsing.
58+
* @return the created {@link CohereCompletionServiceSettings}.
59+
* @throws ValidationException If there are validation errors in the provided settings.
60+
*/
4861
public static CohereCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
49-
ValidationException validationException = new ValidationException();
62+
var validationException = new ValidationException();
5063

51-
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
52-
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
53-
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
54-
map,
55-
DEFAULT_RATE_LIMIT_SETTINGS,
56-
validationException,
57-
CohereService.NAME,
58-
context
59-
);
60-
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
64+
var uri = extractOptionalUri(map, URL, validationException);
65+
var rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException, CohereService.NAME, context);
66+
var modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
6167
var apiVersion = apiVersionFromMap(map, context, validationException);
62-
if (apiVersion == CohereServiceSettings.CohereApiVersion.V2) {
63-
if (modelId == null) {
64-
validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API);
65-
}
68+
if (apiVersion == CohereServiceSettings.CohereApiVersion.V2 && modelId == null) {
69+
validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API);
6670
}
6771

6872
validationException.throwIfValidationErrorsExist();
@@ -96,6 +100,11 @@ public CohereCompletionServiceSettings(
96100
this(createOptionalUri(url), modelId, rateLimitSettings, apiVersion);
97101
}
98102

103+
/**
104+
* Creates {@link CohereCompletionServiceSettings} from a {@link StreamInput}.
105+
* @param in the stream input
106+
* @throws IOException if an I/O exception occurs
107+
*/
99108
public CohereCompletionServiceSettings(StreamInput in) throws IOException {
100109
uri = createOptionalUri(in.readOptionalString());
101110
modelId = in.readOptionalString();
@@ -125,6 +134,23 @@ public String modelId() {
125134
return modelId;
126135
}
127136

137+
@Override
138+
public CohereCompletionServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
139+
var validationException = new ValidationException();
140+
141+
var extractedRateLimitSettings = RateLimitSettings.of(
142+
serviceSettings,
143+
this.rateLimitSettings,
144+
validationException,
145+
CohereService.NAME,
146+
ConfigurationParseContext.REQUEST
147+
);
148+
149+
validationException.throwIfValidationErrorsExist();
150+
151+
return new CohereCompletionServiceSettings(this.uri, this.modelId, extractedRateLimitSettings, this.apiVersion);
152+
}
153+
128154
@Override
129155
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
130156
builder.startObject();
@@ -171,6 +197,11 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
171197
return builder;
172198
}
173199

200+
@Override
201+
public String toString() {
202+
return Strings.toString(this);
203+
}
204+
174205
@Override
175206
public boolean equals(Object object) {
176207
if (this == object) return true;

0 commit comments

Comments
 (0)