99
1010import com .carrotsearch .randomizedtesting .annotations .Name ;
1111
12+ import org .elasticsearch .client .ResponseException ;
1213import org .elasticsearch .common .Strings ;
1314import org .elasticsearch .inference .TaskType ;
15+ import org .elasticsearch .test .http .MockRequest ;
1416import org .elasticsearch .test .http .MockResponse ;
1517import org .elasticsearch .test .http .MockWebServer ;
1618import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingType ;
2426
2527import static org .hamcrest .Matchers .anEmptyMap ;
2628import static org .hamcrest .Matchers .anyOf ;
29+ import static org .hamcrest .Matchers .containsString ;
2730import static org .hamcrest .Matchers .empty ;
2831import static org .hamcrest .Matchers .hasEntry ;
2932import static org .hamcrest .Matchers .hasSize ;
@@ -35,11 +38,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3538
3639 private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0" ;
3740 private static final String COHERE_RERANK_ADDED = "8.14.0" ;
38- private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0 " ;
41+ private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2 " ;
3942
4043 private static MockWebServer cohereEmbeddingsServer ;
4144 private static MockWebServer cohereRerankServer ;
4245
46+ private enum ApiVersion {
47+ V1 ,
48+ V2
49+ }
50+
/**
 * Creates the test instance and hands the injected parameter straight to the
 * parent upgrade test case.
 *
 * @param upgradedNodes value supplied by the randomized-testing framework under the
 *                      name {@code "upgradedNodes"}; presumably the number of nodes
 *                      already upgraded in this run (confirm against the parent class)
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
    super(upgradedNodes);
}
@@ -64,14 +72,15 @@ public void testCohereEmbeddings() throws IOException {
6472 var embeddingsSupported = getOldClusterTestVersion ().onOrAfter (COHERE_EMBEDDINGS_ADDED );
6573 // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS
6674 String oldClusterEndpointIdentifier = oldClusterHasFeature ("gte_v" + MODELS_RENAMED_TO_ENDPOINTS ) ? "endpoints" : "models" ;
67- assumeTrue ( "Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED , embeddingsSupported ) ;
75+ ApiVersion oldClusterApiVersion = oldClusterHasFeature ( COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion . V2 : ApiVersion . V1 ;
6876
6977 final String oldClusterIdInt8 = "old-cluster-embeddings-int8" ;
7078 final String oldClusterIdFloat = "old-cluster-embeddings-float" ;
7179
7280 var testTaskType = TaskType .TEXT_EMBEDDING ;
7381
7482 if (isOldCluster ()) {
83+
7584 // queue a response as PUT will call the service
7685 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
7786 put (oldClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
@@ -129,13 +138,29 @@ public void testCohereEmbeddings() throws IOException {
129138
130139 // Inference on old cluster models
131140 assertEmbeddingInference (oldClusterIdInt8 , CohereEmbeddingType .BYTE );
141+ assertVersionInPath (
142+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
143+ "embed" ,
144+ oldClusterApiVersion
145+ );
132146 assertEmbeddingInference (oldClusterIdFloat , CohereEmbeddingType .FLOAT );
147+ assertVersionInPath (
148+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
149+ "embed" ,
150+ oldClusterApiVersion
151+ );
133152
134153 {
135154 final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte" ;
136155
156+ // new endpoints use the V2 API
137157 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
138158 put (upgradedClusterIdByte , embeddingConfigByte (getUrl (cohereEmbeddingsServer )), testTaskType );
159+ assertVersionInPath (
160+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
161+ "embed" ,
162+ ApiVersion .V2
163+ );
139164
140165 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdByte ).get ("endpoints" );
141166 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
@@ -147,34 +172,86 @@ public void testCohereEmbeddings() throws IOException {
147172 {
148173 final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8" ;
149174
175+ // new endpoints use the V2 API
150176 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
151177 put (upgradedClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
178+ assertVersionInPath (
179+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
180+ "embed" ,
181+ ApiVersion .V2
182+ );
152183
153184 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdInt8 ).get ("endpoints" );
154185 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
155186 assertThat (serviceSettings , hasEntry ("embedding_type" , "byte" )); // int8 rewritten to byte
156187
157188 assertEmbeddingInference (upgradedClusterIdInt8 , CohereEmbeddingType .INT8 );
189+ assertVersionInPath (
190+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
191+ "embed" ,
192+ ApiVersion .V2
193+ );
158194 delete (upgradedClusterIdInt8 );
159195 }
160196 {
161197 final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float" ;
162198 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseFloat ()));
163199 put (upgradedClusterIdFloat , embeddingConfigFloat (getUrl (cohereEmbeddingsServer )), testTaskType );
200+ assertVersionInPath (
201+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
202+ "embed" ,
203+ ApiVersion .V2
204+ );
164205
165206 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdFloat ).get ("endpoints" );
166207 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
167208 assertThat (serviceSettings , hasEntry ("embedding_type" , "float" ));
168209
169210 assertEmbeddingInference (upgradedClusterIdFloat , CohereEmbeddingType .FLOAT );
211+ assertVersionInPath (
212+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
213+ "embed" ,
214+ ApiVersion .V2
215+ );
170216 delete (upgradedClusterIdFloat );
171217 }
218+ {
219+ // new endpoints use the V2 API which require the model to be set
220+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
221+ var jsonBody = Strings .format ("""
222+ {
223+ "service": "cohere",
224+ "service_settings": {
225+ "url": "%s",
226+ "api_key": "XXXX",
227+ "embedding_type": "int8"
228+ }
229+ }
230+ """ , getUrl (cohereEmbeddingsServer ));
231+
232+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , testTaskType ));
233+ assertThat (
234+ e .getMessage (),
235+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
236+ );
237+ }
172238
173239 delete (oldClusterIdFloat );
174240 delete (oldClusterIdInt8 );
175241 }
176242 }
177243
244+ private void assertVersionInPath (MockRequest request , String endpoint , ApiVersion apiVersion ) {
245+ switch (apiVersion ) {
246+ case V2 :
247+ assertEquals ("/v2/" + endpoint , request .getUri ().getPath ());
248+ break ;
249+ case V1 :
250+ assertEquals ("/v1/" + endpoint , request .getUri ().getPath ());
251+ break ;
252+ }
253+ }
254+
178255 void assertEmbeddingInference (String inferenceId , CohereEmbeddingType type ) throws IOException {
179256 switch (type ) {
180257 case INT8 :
@@ -195,6 +272,8 @@ public void testRerank() throws IOException {
195272 String old_cluster_endpoint_identifier = oldClusterHasFeature ("gte_v" + MODELS_RENAMED_TO_ENDPOINTS ) ? "endpoints" : "models" ;
196273 assumeTrue ("Cohere rerank service added in " + COHERE_RERANK_ADDED , rerankSupported );
197274
275+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
276+
198277 final String oldClusterId = "old-cluster-rerank" ;
199278 final String upgradedClusterId = "upgraded-cluster-rerank" ;
200279
@@ -217,7 +296,6 @@ public void testRerank() throws IOException {
217296 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
218297
219298 assertRerank (oldClusterId );
220-
221299 } else if (isUpgradedCluster ()) {
222300 // check old cluster model
223301 var configs = (List <Map <String , Object >>) get (testTaskType , oldClusterId ).get ("endpoints" );
@@ -228,6 +306,11 @@ public void testRerank() throws IOException {
228306 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
229307
230308 assertRerank (oldClusterId );
309+ assertVersionInPath (
310+ cohereRerankServer .requests ().get (cohereRerankServer .requests ().size () - 1 ),
311+ "rerank" ,
312+ oldClusterApiVersion
313+ );
231314
232315 // New endpoint
233316 cohereRerankServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (rerankResponse ()));
@@ -236,6 +319,27 @@ public void testRerank() throws IOException {
236319 assertThat (configs , hasSize (1 ));
237320
238321 assertRerank (upgradedClusterId );
322+ assertVersionInPath (cohereRerankServer .requests ().get (cohereRerankServer .requests ().size () - 1 ), "rerank" , ApiVersion .V2 );
323+
324+ {
325+ // new endpoints use the V2 API which require the model_id to be set
326+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
327+ var jsonBody = Strings .format ("""
328+ {
329+ "service": "cohere",
330+ "service_settings": {
331+ "url": "%s",
332+ "api_key": "XXXX"
333+ }
334+ }
335+ """ , getUrl (cohereEmbeddingsServer ));
336+
337+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , testTaskType ));
338+ assertThat (
339+ e .getMessage (),
340+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
341+ );
342+ }
239343
240344 delete (oldClusterId );
241345 delete (upgradedClusterId );
0 commit comments