@@ -39,10 +39,12 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
// TODO: replace with proper test features
// Feature identifiers that tests look up on the old cluster (via oldClusterHasFeature) before running.
private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0";
private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0";
// testCohereCompletions is skipped (assumeTrue) when the old cluster does not have this feature.
private static final String COHERE_COMPLETIONS_ADDED_TEST_FEATURE = "gte_v8.15.0";
// When the old cluster has this feature, endpoints created on it are expected to call the V2 API; otherwise V1.
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";

// Mock web servers whose URLs are used as the service URL of the test endpoints; all are closed in shutdown().
private static MockWebServer cohereEmbeddingsServer;
private static MockWebServer cohereRerankServer;
private static MockWebServer cohereCompletionsServer;
4648
4749 private enum ApiVersion {
4850 V1 ,
@@ -60,12 +62,16 @@ public static void startWebServer() throws IOException {
6062
6163 cohereRerankServer = new MockWebServer ();
6264 cohereRerankServer .start ();
65+
66+ cohereCompletionsServer = new MockWebServer ();
67+ cohereCompletionsServer .start ();
6368 }
6469
6570 @ AfterClass
6671 public static void shutdown () {
6772 cohereEmbeddingsServer .close ();
6873 cohereRerankServer .close ();
74+ cohereCompletionsServer .close ();
6975 }
7076
7177 @ SuppressWarnings ("unchecked" )
@@ -326,6 +332,80 @@ private void assertRerank(String inferenceId) throws IOException {
326332 assertThat (inferenceMap .entrySet (), not (empty ()));
327333 }
328334
335+ @ SuppressWarnings ("unchecked" )
336+ public void testCohereCompletions () throws IOException {
337+ var completionsSupported = oldClusterHasFeature (COHERE_COMPLETIONS_ADDED_TEST_FEATURE );
338+ assumeTrue ("Cohere completions not supported" , completionsSupported );
339+
340+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
341+
342+ final String oldClusterId = "old-cluster-completions" ;
343+
344+ if (isOldCluster ()) {
345+ // queue a response as PUT will call the service
346+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (oldClusterApiVersion )));
347+ put (oldClusterId , completionsConfig (getUrl (cohereCompletionsServer )), TaskType .COMPLETION );
348+
349+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
350+ assertThat (configs , hasSize (1 ));
351+ assertEquals ("cohere" , configs .get (0 ).get ("service" ));
352+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
353+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
354+ } else if (isMixedCluster ()) {
355+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
356+ assertThat (configs , hasSize (1 ));
357+ assertEquals ("cohere" , configs .get (0 ).get ("service" ));
358+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
359+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
360+ } else if (isUpgradedCluster ()) {
361+ // check old cluster model
362+ var configs = (List <Map <String , Object >>) get (TaskType .COMPLETION , oldClusterId ).get ("endpoints" );
363+ var serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
364+ assertThat (serviceSettings , hasEntry ("model_id" , "command" ));
365+
366+ final String newClusterId = "new-cluster-completions" ;
367+ {
368+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (oldClusterApiVersion )));
369+ var inferenceMap = inference (oldClusterId , TaskType .COMPLETION , "some text" );
370+ assertThat (inferenceMap .entrySet (), not (empty ()));
371+ assertVersionInPath (cohereCompletionsServer .requests ().getLast (), "chat" , oldClusterApiVersion );
372+ }
373+ {
374+ // new cluster uses the V2 API
375+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (ApiVersion .V2 )));
376+ put (newClusterId , completionsConfig (getUrl (cohereCompletionsServer )), TaskType .COMPLETION );
377+
378+ cohereCompletionsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (completionsResponse (ApiVersion .V2 )));
379+ var inferenceMap = inference (newClusterId , TaskType .COMPLETION , "some text" );
380+ assertThat (inferenceMap .entrySet (), not (empty ()));
381+ assertVersionInPath (cohereCompletionsServer .requests ().getLast (), "chat" , ApiVersion .V2 );
382+ }
383+
384+ {
385+ // new endpoints use the V2 API which require the model to be set
386+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
387+ var jsonBody = Strings .format ("""
388+ {
389+ "service": "cohere",
390+ "service_settings": {
391+ "url": "%s",
392+ "api_key": "XXXX"
393+ }
394+ }
395+ """ , getUrl (cohereEmbeddingsServer ));
396+
397+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , TaskType .COMPLETION ));
398+ assertThat (
399+ e .getMessage (),
400+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
401+ );
402+ }
403+
404+ delete (oldClusterId );
405+ delete (newClusterId );
406+ }
407+ }
408+
329409 private String embeddingConfigByte (String url ) {
330410 return embeddingConfigTemplate (url , "byte" );
331411 }
@@ -451,4 +531,86 @@ private String rerankResponse() {
451531 """ ;
452532 }
453533
534+ private String completionsConfig (String url ) {
535+ return Strings .format ("""
536+ {
537+ "service": "cohere",
538+ "service_settings": {
539+ "api_key": "XXXX",
540+ "model_id": "command",
541+ "url": "%s"
542+ }
543+ }
544+ """ , url );
545+ }
546+
547+ private String completionsResponse (ApiVersion version ) {
548+ return switch (version ) {
549+ case V1 -> v1CompletionsResponse ();
550+ case V2 -> v2CompletionsResponse ();
551+ };
552+ }
553+
/**
 * Returns the canned JSON body that completionsResponse selects for {@code ApiVersion.V1}.
 * The generated text is carried in the top-level {@code text} field.
 */
private String v1CompletionsResponse() {
    return """
        {
            "response_id": "some id",
            "text": "result",
            "generation_id": "some id",
            "chat_history": [
                {
                    "role": "USER",
                    "message": "some input"
                },
                {
                    "role": "CHATBOT",
                    "message": "v1 response from the llm"
                }
            ],
            "finish_reason": "COMPLETE",
            "meta": {
                "api_version": {
                    "version": "1"
                },
                "billed_units": {
                    "input_tokens": 4,
                    "output_tokens": 191
                },
                "tokens": {
                    "input_tokens": 70,
                    "output_tokens": 191
                }
            }
        }
        """;
}
587+
/**
 * Returns the canned JSON body that completionsResponse selects for {@code ApiVersion.V2}.
 * The generated text is carried in {@code message.content[0].text}.
 */
private String v2CompletionsResponse() {
    return """
        {
            "id": "c14c80c3-18eb-4519-9460-6c92edd8cfb4",
            "finish_reason": "COMPLETE",
            "message": {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "v2 response from the LLM"
                    }
                ]
            },
            "usage": {
                "billed_units": {
                    "input_tokens": 1,
                    "output_tokens": 2
                },
                "tokens": {
                    "input_tokens": 3,
                    "output_tokens": 4
                }
            }
        }
        """;
}
615+
454616}
0 commit comments