99
1010import com .carrotsearch .randomizedtesting .annotations .Name ;
1111
12+ import org .elasticsearch .client .ResponseException ;
1213import org .elasticsearch .common .Strings ;
1314import org .elasticsearch .inference .TaskType ;
15+ import org .elasticsearch .test .http .MockRequest ;
1416import org .elasticsearch .test .http .MockResponse ;
1517import org .elasticsearch .test .http .MockWebServer ;
1618import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingType ;
2426
2527import static org .hamcrest .Matchers .anEmptyMap ;
2628import static org .hamcrest .Matchers .anyOf ;
29+ import static org .hamcrest .Matchers .containsString ;
2730import static org .hamcrest .Matchers .empty ;
2831import static org .hamcrest .Matchers .hasEntry ;
2932import static org .hamcrest .Matchers .hasSize ;
@@ -36,10 +39,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3639 // TODO: replace with proper test features
3740 private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0" ;
3841 private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0" ;
42+ private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2" ;
3943
4044 private static MockWebServer cohereEmbeddingsServer ;
4145 private static MockWebServer cohereRerankServer ;
4246
47+ private enum ApiVersion {
48+ V1 ,
49+ V2
50+ }
51+
/**
 * Passes the {@code upgradedNodes} test parameter straight through to the
 * {@link InferenceUpgradeTestCase} constructor.
 *
 * @param upgradedNodes value supplied under the parameter name {@code "upgradedNodes"}
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
    super(upgradedNodes);
}
@@ -62,15 +71,18 @@ public static void shutdown() {
6271 @ SuppressWarnings ("unchecked" )
6372 public void testCohereEmbeddings () throws IOException {
6473 var embeddingsSupported = oldClusterHasFeature (COHERE_EMBEDDINGS_ADDED_TEST_FEATURE );
65- String oldClusterEndpointIdentifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
6674 assumeTrue ("Cohere embedding service supported" , embeddingsSupported );
6775
76+ String oldClusterEndpointIdentifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
77+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
78+
6879 final String oldClusterIdInt8 = "old-cluster-embeddings-int8" ;
6980 final String oldClusterIdFloat = "old-cluster-embeddings-float" ;
7081
7182 var testTaskType = TaskType .TEXT_EMBEDDING ;
7283
7384 if (isOldCluster ()) {
85+
7486 // queue a response as PUT will call the service
7587 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
7688 put (oldClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
@@ -128,13 +140,17 @@ public void testCohereEmbeddings() throws IOException {
128140
129141 // Inference on old cluster models
130142 assertEmbeddingInference (oldClusterIdInt8 , CohereEmbeddingType .BYTE );
143+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , oldClusterApiVersion );
131144 assertEmbeddingInference (oldClusterIdFloat , CohereEmbeddingType .FLOAT );
145+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , oldClusterApiVersion );
132146
133147 {
134148 final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte" ;
135149
150+ // new endpoints use the V2 API
136151 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
137152 put (upgradedClusterIdByte , embeddingConfigByte (getUrl (cohereEmbeddingsServer )), testTaskType );
153+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
138154
139155 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdByte ).get ("endpoints" );
140156 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
@@ -146,34 +162,70 @@ public void testCohereEmbeddings() throws IOException {
146162 {
147163 final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8" ;
148164
165+ // new endpoints use the V2 API
149166 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
150167 put (upgradedClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
168+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
151169
152170 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdInt8 ).get ("endpoints" );
153171 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
154172 assertThat (serviceSettings , hasEntry ("embedding_type" , "byte" )); // int8 rewritten to byte
155173
156174 assertEmbeddingInference (upgradedClusterIdInt8 , CohereEmbeddingType .INT8 );
175+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
157176 delete (upgradedClusterIdInt8 );
158177 }
159178 {
160179 final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float" ;
161180 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseFloat ()));
162181 put (upgradedClusterIdFloat , embeddingConfigFloat (getUrl (cohereEmbeddingsServer )), testTaskType );
182+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
163183
164184 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdFloat ).get ("endpoints" );
165185 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
166186 assertThat (serviceSettings , hasEntry ("embedding_type" , "float" ));
167187
168188 assertEmbeddingInference (upgradedClusterIdFloat , CohereEmbeddingType .FLOAT );
189+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
169190 delete (upgradedClusterIdFloat );
170191 }
192+ {
193+ // new endpoints use the V2 API, which requires the model to be set
194+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
195+ var jsonBody = Strings .format ("""
196+ {
197+ "service": "cohere",
198+ "service_settings": {
199+ "url": "%s",
200+ "api_key": "XXXX",
201+ "embedding_type": "int8"
202+ }
203+ }
204+ """ , getUrl (cohereEmbeddingsServer ));
205+
206+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , testTaskType ));
207+ assertThat (
208+ e .getMessage (),
209+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
210+ );
211+ }
171212
172213 delete (oldClusterIdFloat );
173214 delete (oldClusterIdInt8 );
174215 }
175216 }
176217
218+ private void assertVersionInPath (MockRequest request , String endpoint , ApiVersion apiVersion ) {
219+ switch (apiVersion ) {
220+ case V2 :
221+ assertEquals ("/v2/" + endpoint , request .getUri ().getPath ());
222+ break ;
223+ case V1 :
224+ assertEquals ("/v1/" + endpoint , request .getUri ().getPath ());
225+ break ;
226+ }
227+ }
228+
177229 void assertEmbeddingInference (String inferenceId , CohereEmbeddingType type ) throws IOException {
178230 switch (type ) {
179231 case INT8 :
@@ -191,9 +243,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
191243 @ SuppressWarnings ("unchecked" )
192244 public void testRerank () throws IOException {
193245 var rerankSupported = oldClusterHasFeature (COHERE_RERANK_ADDED_TEST_FEATURE );
194- String old_cluster_endpoint_identifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
195246 assumeTrue ("Cohere rerank service supported" , rerankSupported );
196247
248+ String old_cluster_endpoint_identifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
249+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
250+
197251 final String oldClusterId = "old-cluster-rerank" ;
198252 final String upgradedClusterId = "upgraded-cluster-rerank" ;
199253
@@ -216,7 +270,6 @@ public void testRerank() throws IOException {
216270 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
217271
218272 assertRerank (oldClusterId );
219-
220273 } else if (isUpgradedCluster ()) {
221274 // check old cluster model
222275 var configs = (List <Map <String , Object >>) get (testTaskType , oldClusterId ).get ("endpoints" );
@@ -227,6 +280,7 @@ public void testRerank() throws IOException {
227280 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
228281
229282 assertRerank (oldClusterId );
283+ assertVersionInPath (cohereRerankServer .requests ().getLast (), "rerank" , oldClusterApiVersion );
230284
231285 // New endpoint
232286 cohereRerankServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (rerankResponse ()));
@@ -235,6 +289,27 @@ public void testRerank() throws IOException {
235289 assertThat (configs , hasSize (1 ));
236290
237291 assertRerank (upgradedClusterId );
292+ assertVersionInPath (cohereRerankServer .requests ().getLast (), "rerank" , ApiVersion .V2 );
293+
{
    // New endpoints use the V2 API, which requires model_id to be set, so this PUT is
    // expected to fail validation. The URL points at the rerank mock server (the one the
    // rest of this test uses) rather than the embeddings server, so that a validation
    // regression would not send the request to an unrelated mock.
    final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
    var jsonBody = Strings.format("""
        {
            "service": "cohere",
            "service_settings": {
                "url": "%s",
                "api_key": "XXXX"
            }
        }
        """, getUrl(cohereRerankServer));

    var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
    assertThat(
        e.getMessage(),
        containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
    );
}
238313
239314 delete (oldClusterId );
240315 delete (upgradedClusterId );
0 commit comments