Skip to content

Commit 5be4100

Browse files
authored
[ML] Apply patch logic for Cohere V2 transport changes (#129993)
1 parent 2bc6284 commit 5be4100

File tree

6 files changed

+56
-18
lines changed

6 files changed

+56
-18
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ public CohereServiceSettings(StreamInput in) throws IOException {
183183
rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
184184
}
185185
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
186-
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
186+
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
187187
this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class);
188188
} else {
189189
this.apiVersion = CohereServiceSettings.CohereApiVersion.V1;
@@ -286,7 +286,7 @@ public void writeTo(StreamOutput out) throws IOException {
286286
rateLimitSettings.writeTo(out);
287287
}
288288
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
289-
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
289+
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
290290
out.writeEnum(apiVersion);
291291
}
292292
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public CohereCompletionServiceSettings(StreamInput in) throws IOException {
103103
modelId = in.readOptionalString();
104104
rateLimitSettings = new RateLimitSettings(in);
105105
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
106-
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
106+
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
107107
this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class);
108108
} else {
109109
this.apiVersion = CohereServiceSettings.CohereApiVersion.V1;
@@ -156,7 +156,7 @@ public void writeTo(StreamOutput out) throws IOException {
156156
out.writeOptionalString(modelId);
157157
rateLimitSettings.writeTo(out);
158158
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
159-
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
159+
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
160160
out.writeEnum(apiVersion);
161161
}
162162
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public CohereRerankServiceSettings(StreamInput in) throws IOException {
125125
}
126126

127127
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
128-
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
128+
|| in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
129129
this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class);
130130
} else {
131131
this.apiVersion = CohereServiceSettings.CohereApiVersion.V1;
@@ -207,7 +207,7 @@ public void writeTo(StreamOutput out) throws IOException {
207207
rateLimitSettings.writeTo(out);
208208
}
209209
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
210-
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
210+
|| out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) {
211211
out.writeEnum(apiVersion);
212212
}
213213
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.ValidationException;
1214
import org.elasticsearch.common.io.stream.Writeable;
1315
import org.elasticsearch.core.Nullable;
1416
import org.elasticsearch.inference.SimilarityMeasure;
15-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1617
import org.elasticsearch.xcontent.XContentBuilder;
1718
import org.elasticsearch.xcontent.XContentFactory;
1819
import org.elasticsearch.xcontent.XContentType;
20+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1921
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2022
import org.elasticsearch.xpack.inference.services.ServiceFields;
2123
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -30,7 +32,7 @@
3032
import static org.hamcrest.Matchers.containsString;
3133
import static org.hamcrest.Matchers.is;
3234

33-
public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase<CohereServiceSettings> {
35+
public class CohereServiceSettingsTests extends AbstractBWCWireSerializationTestCase<CohereServiceSettings> {
3436

3537
public static CohereServiceSettings createRandomWithNonNullUrl() {
3638
return createRandom(randomAlphaOfLength(15));
@@ -359,4 +361,22 @@ public static Map<String, Object> getServiceSettingsMap(@Nullable String url, @N
359361

360362
return map;
361363
}
364+
365+
@Override
366+
protected CohereServiceSettings mutateInstanceForVersion(CohereServiceSettings instance, TransportVersion version) {
367+
if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
368+
&& (version.isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19) == false)) {
369+
return new CohereServiceSettings(
370+
instance.uri(),
371+
instance.similarity(),
372+
instance.dimensions(),
373+
instance.maxInputTokens(),
374+
instance.modelId(),
375+
instance.rateLimitSettings(),
376+
CohereServiceSettings.CohereApiVersion.V1
377+
);
378+
}
379+
380+
return instance;
381+
}
362382
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere.completion;
99

10+
import org.elasticsearch.TransportVersion;
11+
import org.elasticsearch.TransportVersions;
1012
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.io.stream.Writeable;
12-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
1314
import org.elasticsearch.xcontent.XContentBuilder;
1415
import org.elasticsearch.xcontent.XContentFactory;
1516
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1618
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
1719
import org.elasticsearch.xpack.inference.services.ServiceFields;
1820
import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
@@ -25,7 +27,7 @@
2527

2628
import static org.hamcrest.Matchers.is;
2729

28-
public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<CohereCompletionServiceSettings> {
30+
public class CohereCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase<CohereCompletionServiceSettings> {
2931

3032
public static CohereCompletionServiceSettings createRandom() {
3133
return new CohereCompletionServiceSettings(
@@ -110,4 +112,19 @@ protected CohereCompletionServiceSettings createTestInstance() {
110112
protected CohereCompletionServiceSettings mutateInstance(CohereCompletionServiceSettings instance) throws IOException {
111113
return randomValueOtherThan(instance, this::createTestInstance);
112114
}
115+
116+
@Override
117+
protected CohereCompletionServiceSettings mutateInstanceForVersion(CohereCompletionServiceSettings instance, TransportVersion version) {
118+
if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
119+
&& (version.isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19) == false)) {
120+
return new CohereCompletionServiceSettings(
121+
instance.uri(),
122+
instance.modelId(),
123+
instance.rateLimitSettings(),
124+
CohereServiceSettings.CohereApiVersion.V1
125+
);
126+
}
127+
128+
return instance;
129+
}
113130
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,15 @@ protected CohereRerankServiceSettings mutateInstanceForVersion(CohereRerankServi
8888
CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS,
8989
CohereServiceSettings.CohereApiVersion.V1
9090
);
91-
} else if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) {
92-
return new CohereRerankServiceSettings(
93-
instance.uri(),
94-
instance.modelId(),
95-
instance.rateLimitSettings(),
96-
CohereServiceSettings.CohereApiVersion.V1
97-
);
98-
}
91+
} else if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)
92+
&& version.isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19) == false) {
93+
return new CohereRerankServiceSettings(
94+
instance.uri(),
95+
instance.modelId(),
96+
instance.rateLimitSettings(),
97+
CohereServiceSettings.CohereApiVersion.V1
98+
);
99+
}
99100
return instance;
100101
}
101102

0 commit comments

Comments
 (0)