Skip to content

Commit 5eca6b8

Browse files
timgreinalbertzaharovits
authored andcommitted
[Inference API] Remove unused Cohere rerank service settings fields in a BWC way (#110427)
1 parent bb99c78 commit 5eca6b8

File tree

8 files changed

+164
-53
lines changed

8 files changed

+164
-53
lines changed

docs/changelog/110427.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 110427
2+
summary: "[Inference API] Remove unused Cohere rerank service settings fields in a\
3+
\ BWC way"
4+
area: Machine Learning
5+
type: bug
6+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ static TransportVersion def(int id) {
212212
public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0);
213213
public static final TransportVersion INFERENCE_ADAPTIVE_ALLOCATIONS = def(8_704_00_0);
214214
public static final TransportVersion INDEX_REQUEST_UPDATE_BY_SCRIPT_ORIGIN = def(8_705_00_0);
215+
public static final TransportVersion ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED = def(8_706_00_0);
215216

216217
/*
217218
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereRerankAction.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ public class CohereRerankAction implements ExecutableAction {
3232
public CohereRerankAction(Sender sender, CohereRerankModel model, ThreadPool threadPool) {
3333
Objects.requireNonNull(model);
3434
this.sender = Objects.requireNonNull(sender);
35-
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
36-
model.getServiceSettings().getCommonSettings().uri(),
37-
"Cohere rerank"
38-
);
35+
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.getServiceSettings().uri(), "Cohere rerank");
3936
requestCreator = CohereRerankRequestManager.of(model, threadPool);
4037
}
4138

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public CohereRerankRequest(String query, List<String> input, CohereRerankModel m
3939
this.input = Objects.requireNonNull(input);
4040
this.query = Objects.requireNonNull(query);
4141
taskSettings = model.getTaskSettings();
42-
this.model = model.getServiceSettings().getCommonSettings().modelId();
42+
this.model = model.getServiceSettings().modelId();
4343
inferenceEntityId = model.getInferenceEntityId();
4444
}
4545

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
4646
private static final Logger logger = LogManager.getLogger(CohereServiceSettings.class);
4747
// Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications
4848
// 10K requests a minute
49-
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
49+
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
5050

5151
public static CohereServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
5252
ValidationException validationException = new ValidationException();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public CohereRerankModel(
5959
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
6060
new ModelSecrets(secretSettings),
6161
secretSettings,
62-
serviceSettings.getCommonSettings()
62+
serviceSettings
6363
);
6464
}
6565

@@ -100,6 +100,6 @@ public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object>
100100

101101
@Override
102102
public URI uri() {
103-
return getServiceSettings().getCommonSettings().uri();
103+
return getServiceSettings().uri();
104104
}
105105
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java

Lines changed: 118 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,118 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere.rerank;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.elasticsearch.TransportVersion;
1113
import org.elasticsearch.TransportVersions;
1214
import org.elasticsearch.common.ValidationException;
1315
import org.elasticsearch.common.io.stream.StreamInput;
1416
import org.elasticsearch.common.io.stream.StreamOutput;
17+
import org.elasticsearch.core.Nullable;
18+
import org.elasticsearch.inference.ModelConfigurations;
1519
import org.elasticsearch.inference.ServiceSettings;
20+
import org.elasticsearch.inference.SimilarityMeasure;
1621
import org.elasticsearch.xcontent.XContentBuilder;
1722
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
18-
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
23+
import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
24+
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
1925
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
26+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2027

2128
import java.io.IOException;
29+
import java.net.URI;
2230
import java.util.Map;
2331
import java.util.Objects;
2432

25-
public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings {
33+
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
34+
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
35+
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
36+
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
37+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
38+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
39+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
40+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
41+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
42+
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS;
43+
44+
public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings {
2645
public static final String NAME = "cohere_rerank_service_settings";
2746

28-
public static CohereRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext parseContext) {
47+
private static final Logger logger = LogManager.getLogger(CohereRerankServiceSettings.class);
48+
49+
public static CohereRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
2950
ValidationException validationException = new ValidationException();
30-
var commonServiceSettings = CohereServiceSettings.fromMap(map, parseContext);
51+
52+
String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
53+
54+
// We need to extract/remove those fields to avoid unknown service settings errors
55+
extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
56+
removeAsType(map, DIMENSIONS, Integer.class);
57+
removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
58+
59+
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
60+
String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
61+
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
62+
map,
63+
DEFAULT_RATE_LIMIT_SETTINGS,
64+
validationException,
65+
CohereService.NAME,
66+
context
67+
);
3168

3269
if (validationException.validationErrors().isEmpty() == false) {
3370
throw validationException;
3471
}
3572

36-
return new CohereRerankServiceSettings(commonServiceSettings);
73+
return new CohereRerankServiceSettings(uri, modelId, rateLimitSettings);
3774
}
3875

39-
private final CohereServiceSettings commonSettings;
76+
private final URI uri;
77+
78+
private final String modelId;
79+
80+
private final RateLimitSettings rateLimitSettings;
81+
82+
public CohereRerankServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) {
83+
this.uri = uri;
84+
this.modelId = modelId;
85+
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
86+
}
4087

41-
public CohereRerankServiceSettings(CohereServiceSettings commonSettings) {
42-
this.commonSettings = commonSettings;
88+
public CohereRerankServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) {
89+
this(createOptionalUri(url), modelId, rateLimitSettings);
4390
}
4491

4592
public CohereRerankServiceSettings(StreamInput in) throws IOException {
46-
commonSettings = new CohereServiceSettings(in);
93+
this.uri = createOptionalUri(in.readOptionalString());
94+
95+
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) {
96+
// An older node sends these fields, so we need to skip them to progress through the serialized data
97+
in.readOptionalEnum(SimilarityMeasure.class);
98+
in.readOptionalVInt();
99+
in.readOptionalVInt();
100+
}
101+
102+
this.modelId = in.readOptionalString();
103+
104+
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED)) {
105+
this.rateLimitSettings = new RateLimitSettings(in);
106+
} else {
107+
this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
108+
}
109+
}
110+
111+
public URI uri() {
112+
return uri;
113+
}
114+
115+
public String modelId() {
116+
return modelId;
117+
}
118+
119+
@Override
120+
public RateLimitSettings rateLimitSettings() {
121+
return rateLimitSettings;
47122
}
48123

49124
@Override
@@ -55,15 +130,23 @@ public String getWriteableName() {
55130
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
56131
builder.startObject();
57132

58-
commonSettings.toXContentFragment(builder, params);
133+
toXContentFragmentOfExposedFields(builder, params);
59134

60135
builder.endObject();
61136
return builder;
62137
}
63138

64139
@Override
65140
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
66-
commonSettings.toXContentFragmentOfExposedFields(builder, params);
141+
if (uri != null) {
142+
builder.field(URL, uri.toString());
143+
}
144+
145+
if (modelId != null) {
146+
builder.field(MODEL_ID, modelId);
147+
}
148+
149+
rateLimitSettings.toXContent(builder, params);
67150

68151
return builder;
69152
}
@@ -75,23 +158,36 @@ public TransportVersion getMinimalSupportedVersion() {
75158

76159
@Override
77160
public void writeTo(StreamOutput out) throws IOException {
78-
commonSettings.writeTo(out);
161+
var uriToWrite = uri != null ? uri.toString() : null;
162+
out.writeOptionalString(uriToWrite);
163+
164+
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) {
165+
// An old node expects this data to be present, so we need to send at least the booleans
166+
// indicating that the fields are not set
167+
out.writeOptionalEnum(null);
168+
out.writeOptionalVInt(null);
169+
out.writeOptionalVInt(null);
170+
}
171+
172+
out.writeOptionalString(modelId);
173+
174+
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED)) {
175+
rateLimitSettings.writeTo(out);
176+
}
79177
}
80178

81179
@Override
82-
public boolean equals(Object o) {
83-
if (this == o) return true;
84-
if (o == null || getClass() != o.getClass()) return false;
85-
CohereRerankServiceSettings that = (CohereRerankServiceSettings) o;
86-
return Objects.equals(commonSettings, that.commonSettings);
180+
public boolean equals(Object object) {
181+
if (this == object) return true;
182+
if (object == null || getClass() != object.getClass()) return false;
183+
CohereRerankServiceSettings that = (CohereRerankServiceSettings) object;
184+
return Objects.equals(uri, that.uri)
185+
&& Objects.equals(modelId, that.modelId)
186+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
87187
}
88188

89189
@Override
90190
public int hashCode() {
91-
return Objects.hash(commonSettings);
92-
}
93-
94-
public CohereServiceSettings getCommonSettings() {
95-
return commonSettings;
191+
return Objects.hash(uri, modelId, rateLimitSettings);
96192
}
97193
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,48 +7,58 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere.rerank;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.Strings;
11-
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1213
import org.elasticsearch.common.io.stream.Writeable;
1314
import org.elasticsearch.core.Nullable;
14-
import org.elasticsearch.inference.SimilarityMeasure;
15-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1615
import org.elasticsearch.xcontent.XContentBuilder;
1716
import org.elasticsearch.xcontent.XContentFactory;
1817
import org.elasticsearch.xcontent.XContentType;
19-
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
20-
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
18+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
2119
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
2220
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettingsTests;
2321
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
22+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
2423

2524
import java.io.IOException;
26-
import java.util.ArrayList;
2725
import java.util.HashMap;
28-
import java.util.List;
2926
import java.util.Map;
3027

31-
import static org.hamcrest.Matchers.is;
28+
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
3229

33-
public class CohereRerankServiceSettingsTests extends AbstractWireSerializingTestCase<CohereRerankServiceSettings> {
30+
public class CohereRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<CohereRerankServiceSettings> {
3431
public static CohereRerankServiceSettings createRandom() {
35-
var commonSettings = CohereServiceSettingsTests.createRandom();
32+
return createRandom(randomFrom(new RateLimitSettings[] { null, RateLimitSettingsTests.createRandom() }));
33+
}
3634

37-
return new CohereRerankServiceSettings(commonSettings);
35+
public static CohereRerankServiceSettings createRandom(@Nullable RateLimitSettings rateLimitSettings) {
36+
return new CohereRerankServiceSettings(
37+
randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }),
38+
randomFrom(new String[] { null, randomAlphaOfLength(10) }),
39+
rateLimitSettings
40+
);
3841
}
3942

4043
public void testToXContent_WritesAllValues() throws IOException {
41-
var serviceSettings = new CohereRerankServiceSettings(
42-
new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3))
43-
);
44+
var url = "http://www.abc.com";
45+
var model = "model";
46+
47+
var serviceSettings = new CohereRerankServiceSettings(url, model, null);
4448

4549
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
4650
serviceSettings.toXContent(builder, null);
4751
String xContentResult = Strings.toString(builder);
48-
// TODO we probably shouldn't allow configuring these fields for reranking
49-
assertThat(xContentResult, is("""
50-
{"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """
51-
"rate_limit":{"requests_per_minute":3}}"""));
52+
53+
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
54+
{
55+
"url":"http://www.abc.com",
56+
"model_id":"model",
57+
"rate_limit": {
58+
"requests_per_minute": 10000
59+
}
60+
}
61+
"""));
5262
}
5363

5464
@Override
@@ -67,11 +77,12 @@ protected CohereRerankServiceSettings mutateInstance(CohereRerankServiceSettings
6777
}
6878

6979
@Override
70-
protected NamedWriteableRegistry getNamedWriteableRegistry() {
71-
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
72-
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
73-
entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables());
74-
return new NamedWriteableRegistry(entries);
80+
protected CohereRerankServiceSettings mutateInstanceForVersion(CohereRerankServiceSettings instance, TransportVersion version) {
81+
if (version.before(TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED)) {
82+
// We always default to the same rate limit settings, if a node is on a version before rate limits were introduced
83+
return new CohereRerankServiceSettings(instance.uri(), instance.modelId(), CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS);
84+
}
85+
return instance;
7586
}
7687

7788
public static Map<String, Object> getServiceSettingsMap(@Nullable String url, @Nullable String model) {

0 commit comments

Comments
 (0)