1111
1212import org .elasticsearch .common .Strings ;
1313import org .elasticsearch .inference .TaskType ;
14+ import org .elasticsearch .test .http .MockRequest ;
1415import org .elasticsearch .test .http .MockResponse ;
1516import org .elasticsearch .test .http .MockWebServer ;
1617import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingType ;
@@ -36,10 +37,15 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3637 // TODO: replace with proper test features
3738 private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0" ;
3839 private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0" ;
40+ private static final String V2_API = "gte_v8.19.0" ;
3941
4042 private static MockWebServer cohereEmbeddingsServer ;
4143 private static MockWebServer cohereRerankServer ;
4244
45+ private enum ApiVersion {
46+ V1 , V2
47+ }
48+
4349 public CohereServiceUpgradeIT (@ Name ("upgradedNodes" ) int upgradedNodes ) {
4450 super (upgradedNodes );
4551 }
@@ -62,15 +68,18 @@ public static void shutdown() {
6268 @ SuppressWarnings ("unchecked" )
6369 public void testCohereEmbeddings () throws IOException {
6470 var embeddingsSupported = oldClusterHasFeature (COHERE_EMBEDDINGS_ADDED_TEST_FEATURE );
65- String oldClusterEndpointIdentifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
6671 assumeTrue ("Cohere embedding service supported" , embeddingsSupported );
6772
73+ String oldClusterEndpointIdentifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
74+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (V2_API ) ? ApiVersion .V2 : ApiVersion .V1 ;
75+
6876 final String oldClusterIdInt8 = "old-cluster-embeddings-int8" ;
6977 final String oldClusterIdFloat = "old-cluster-embeddings-float" ;
7078
7179 var testTaskType = TaskType .TEXT_EMBEDDING ;
7280
7381 if (isOldCluster ()) {
82+
7483 // queue a response as PUT will call the service
7584 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
7685 put (oldClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
@@ -128,13 +137,17 @@ public void testCohereEmbeddings() throws IOException {
128137
129138 // Inference on old cluster models
130139 assertEmbeddingInference (oldClusterIdInt8 , CohereEmbeddingType .BYTE );
140+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , oldClusterApiVersion );
131141 assertEmbeddingInference (oldClusterIdFloat , CohereEmbeddingType .FLOAT );
142+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , oldClusterApiVersion );
132143
133144 {
134145 final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte" ;
135146
147+ // new endpoints use the V2 API
136148 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
137149 put (upgradedClusterIdByte , embeddingConfigByte (getUrl (cohereEmbeddingsServer )), testTaskType );
150+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
138151
139152 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdByte ).get ("endpoints" );
140153 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
@@ -146,26 +159,31 @@ public void testCohereEmbeddings() throws IOException {
146159 {
147160 final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8" ;
148161
162+ // new endpoints use the V2 API
149163 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
150164 put (upgradedClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
165+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
151166
152167 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdInt8 ).get ("endpoints" );
153168 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
154169 assertThat (serviceSettings , hasEntry ("embedding_type" , "byte" )); // int8 rewritten to byte
155170
156171 assertEmbeddingInference (upgradedClusterIdInt8 , CohereEmbeddingType .INT8 );
172+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
157173 delete (upgradedClusterIdInt8 );
158174 }
159175 {
160176 final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float" ;
161177 cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseFloat ()));
162178 put (upgradedClusterIdFloat , embeddingConfigFloat (getUrl (cohereEmbeddingsServer )), testTaskType );
179+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
163180
164181 configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdFloat ).get ("endpoints" );
165182 serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
166183 assertThat (serviceSettings , hasEntry ("embedding_type" , "float" ));
167184
168185 assertEmbeddingInference (upgradedClusterIdFloat , CohereEmbeddingType .FLOAT );
186+ assertVersionInPath (cohereEmbeddingsServer .requests ().getLast (), "embed" , ApiVersion .V2 );
169187 delete (upgradedClusterIdFloat );
170188 }
171189
@@ -174,6 +192,17 @@ public void testCohereEmbeddings() throws IOException {
174192 }
175193 }
176194
195+ private void assertVersionInPath (MockRequest request , String endpoint , ApiVersion apiVersion ) {
196+ switch (apiVersion ) {
197+ case V2 :
198+ assertEquals ("/v2/" + endpoint , request .getUri ().getPath ());
199+ break ;
200+ case V1 :
201+ assertEquals ("/v1/" + endpoint , request .getUri ().getPath ());
202+ break ;
203+ }
204+ }
205+
177206 void assertEmbeddingInference (String inferenceId , CohereEmbeddingType type ) throws IOException {
178207 switch (type ) {
179208 case INT8 :
@@ -191,9 +220,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro
191220 @ SuppressWarnings ("unchecked" )
192221 public void testRerank () throws IOException {
193222 var rerankSupported = oldClusterHasFeature (COHERE_RERANK_ADDED_TEST_FEATURE );
194- String old_cluster_endpoint_identifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
195223 assumeTrue ("Cohere rerank service supported" , rerankSupported );
196224
225+ String old_cluster_endpoint_identifier = oldClusterHasFeature (MODELS_RENAMED_TO_ENDPOINTS_FEATURE ) ? "endpoints" : "models" ;
226+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (V2_API ) ? ApiVersion .V2 : ApiVersion .V1 ;
227+
197228 final String oldClusterId = "old-cluster-rerank" ;
198229 final String upgradedClusterId = "upgraded-cluster-rerank" ;
199230
@@ -216,7 +247,6 @@ public void testRerank() throws IOException {
216247 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
217248
218249 assertRerank (oldClusterId );
219-
220250 } else if (isUpgradedCluster ()) {
221251 // check old cluster model
222252 var configs = (List <Map <String , Object >>) get (testTaskType , oldClusterId ).get ("endpoints" );
@@ -227,6 +257,7 @@ public void testRerank() throws IOException {
227257 assertThat (taskSettings , hasEntry ("top_n" , 3 ));
228258
229259 assertRerank (oldClusterId );
260+ assertVersionInPath (cohereRerankServer .requests ().getLast (), "rerank" , oldClusterApiVersion );
230261
231262 // New endpoint
232263 cohereRerankServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (rerankResponse ()));
@@ -235,6 +266,7 @@ public void testRerank() throws IOException {
235266 assertThat (configs , hasSize (1 ));
236267
237268 assertRerank (upgradedClusterId );
269+ assertVersionInPath (cohereRerankServer .requests ().getLast (), "rerank" , ApiVersion .V2 );
238270
239271 delete (oldClusterId );
240272 delete (upgradedClusterId );
0 commit comments