2020import org .elasticsearch .inference .SimilarityMeasure ;
2121import org .elasticsearch .xcontent .XContentBuilder ;
2222import org .elasticsearch .xpack .inference .services .ConfigurationParseContext ;
23+ import org .elasticsearch .xpack .inference .services .ServiceUtils ;
2324import org .elasticsearch .xpack .inference .services .settings .FilteredXContentObject ;
2425import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
2526
2627import java .io .IOException ;
2728import java .net .URI ;
29+ import java .util .EnumSet ;
30+ import java .util .Locale ;
2831import java .util .Map ;
2932import java .util .Objects ;
3033
@@ -43,6 +46,18 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser
4346 public static final String NAME = "cohere_service_settings" ;
4447 public static final String OLD_MODEL_ID_FIELD = "model" ;
4548 public static final String MODEL_ID = "model_id" ;
49+ public static final String API_VERSION = "api_version" ;
50+ public static final String MODEL_REQUIRED_FOR_V2_API = "The [service_settings.model_id] field is required for the Cohere V2 API." ;
51+
52+ public enum CohereApiVersion {
53+ V1 ,
54+ V2 ;
55+
56+ public static CohereApiVersion fromString (String name ) {
57+ return valueOf (name .trim ().toUpperCase (Locale .ROOT ));
58+ }
59+ }
60+
4661 private static final Logger logger = LogManager .getLogger (CohereServiceSettings .class );
4762 // Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications
4863 // 10K requests a minute
@@ -72,11 +87,53 @@ public static CohereServiceSettings fromMap(Map<String, Object> map, Configurati
7287 logger .info ("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead." );
7388 }
7489
90+ var resolvedModelId = modelId (oldModelId , modelId );
91+ var apiVersion = apiVersionFromMap (map , context , validationException );
92+ if (apiVersion == CohereApiVersion .V2 ) {
93+ if (resolvedModelId == null ) {
94+ validationException .addValidationError (MODEL_REQUIRED_FOR_V2_API );
95+ }
96+ }
97+
7598 if (validationException .validationErrors ().isEmpty () == false ) {
7699 throw validationException ;
77100 }
78101
79- return new CohereServiceSettings (uri , similarity , dims , maxInputTokens , modelId (oldModelId , modelId ), rateLimitSettings );
102+ return new CohereServiceSettings (
103+ uri ,
104+ similarity ,
105+ dims ,
106+ maxInputTokens ,
107+ modelId (oldModelId , modelId ),
108+ rateLimitSettings ,
109+ apiVersion
110+ );
111+ }
112+
113+ public static CohereApiVersion apiVersionFromMap (
114+ Map <String , Object > map ,
115+ ConfigurationParseContext context ,
116+ ValidationException validationException
117+ ) {
118+ return switch (context ) {
119+ case REQUEST -> CohereApiVersion .V2 ; // new endpoints all use the V2 API.
120+ case PERSISTENT -> {
121+ var apiVersion = ServiceUtils .extractOptionalEnum (
122+ map ,
123+ API_VERSION ,
124+ ModelConfigurations .SERVICE_SETTINGS ,
125+ CohereApiVersion ::fromString ,
126+ EnumSet .allOf (CohereApiVersion .class ),
127+ validationException
128+ );
129+
130+ if (apiVersion == null ) {
131+ yield CohereApiVersion .V1 ; // If the API version is not persisted then it must be V1
132+ } else {
133+ yield apiVersion ;
134+ }
135+ }
136+ };
80137 }
81138
82139 private static String modelId (@ Nullable String model , @ Nullable String modelId ) {
@@ -89,21 +146,24 @@ private static String modelId(@Nullable String model, @Nullable String modelId)
89146 private final Integer maxInputTokens ;
90147 private final String modelId ;
91148 private final RateLimitSettings rateLimitSettings ;
149+ private final CohereApiVersion apiVersion ;
92150
93151 public CohereServiceSettings (
94152 @ Nullable URI uri ,
95153 @ Nullable SimilarityMeasure similarity ,
96154 @ Nullable Integer dimensions ,
97155 @ Nullable Integer maxInputTokens ,
98156 @ Nullable String modelId ,
99- @ Nullable RateLimitSettings rateLimitSettings
157+ @ Nullable RateLimitSettings rateLimitSettings ,
158+ CohereApiVersion apiVersion
100159 ) {
101160 this .uri = uri ;
102161 this .similarity = similarity ;
103162 this .dimensions = dimensions ;
104163 this .maxInputTokens = maxInputTokens ;
105164 this .modelId = modelId ;
106165 this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
166+ this .apiVersion = apiVersion ;
107167 }
108168
109169 public CohereServiceSettings (
@@ -112,9 +172,10 @@ public CohereServiceSettings(
112172 @ Nullable Integer dimensions ,
113173 @ Nullable Integer maxInputTokens ,
114174 @ Nullable String modelId ,
115- @ Nullable RateLimitSettings rateLimitSettings
175+ @ Nullable RateLimitSettings rateLimitSettings ,
176+ CohereApiVersion apiVersion
116177 ) {
117- this (createOptionalUri (url ), similarity , dimensions , maxInputTokens , modelId , rateLimitSettings );
178+ this (createOptionalUri (url ), similarity , dimensions , maxInputTokens , modelId , rateLimitSettings , apiVersion );
118179 }
119180
120181 public CohereServiceSettings (StreamInput in ) throws IOException {
@@ -129,18 +190,29 @@ public CohereServiceSettings(StreamInput in) throws IOException {
129190 } else {
130191 rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS ;
131192 }
193+ if (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_COHERE_API_VERSION )
194+ || in .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_COHERE_API_VERSION )) {
195+ this .apiVersion = in .readEnum (CohereServiceSettings .CohereApiVersion .class );
196+ } else {
197+ this .apiVersion = CohereServiceSettings .CohereApiVersion .V1 ;
198+ }
132199 }
133200
134201 // should only be used for testing, public because it's accessed outside of the package
135- public CohereServiceSettings () {
136- this ((URI ) null , null , null , null , null , null );
202+ public CohereServiceSettings (CohereApiVersion apiVersion ) {
203+ this ((URI ) null , null , null , null , null , null , apiVersion );
137204 }
138205
139206 @ Override
140207 public RateLimitSettings rateLimitSettings () {
141208 return rateLimitSettings ;
142209 }
143210
211+ @ Override
212+ public CohereApiVersion apiVersion () {
213+ return apiVersion ;
214+ }
215+
144216 public URI uri () {
145217 return uri ;
146218 }
@@ -172,15 +244,14 @@ public String getWriteableName() {
172244 @ Override
173245 public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
174246 builder .startObject ();
175-
176247 toXContentFragment (builder , params );
177-
178248 builder .endObject ();
179249 return builder ;
180250 }
181251
182252 public XContentBuilder toXContentFragment (XContentBuilder builder , Params params ) throws IOException {
183- return toXContentFragmentOfExposedFields (builder , params );
253+ toXContentFragmentOfExposedFields (builder , params );
254+ return builder .field (API_VERSION , apiVersion ); // API version is persisted but not exposed to the user
184255 }
185256
186257 @ Override
@@ -222,6 +293,10 @@ public void writeTo(StreamOutput out) throws IOException {
222293 if (out .getTransportVersion ().onOrAfter (TransportVersions .V_8_15_0 )) {
223294 rateLimitSettings .writeTo (out );
224295 }
296+ if (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_COHERE_API_VERSION )
297+ || out .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_COHERE_API_VERSION )) {
298+ out .writeEnum (apiVersion );
299+ }
225300 }
226301
227302 @ Override
@@ -234,11 +309,12 @@ public boolean equals(Object o) {
234309 && Objects .equals (dimensions , that .dimensions )
235310 && Objects .equals (maxInputTokens , that .maxInputTokens )
236311 && Objects .equals (modelId , that .modelId )
237- && Objects .equals (rateLimitSettings , that .rateLimitSettings );
312+ && Objects .equals (rateLimitSettings , that .rateLimitSettings )
313+ && apiVersion == that .apiVersion ;
238314 }
239315
240316 @ Override
241317 public int hashCode () {
242- return Objects .hash (uri , similarity , dimensions , maxInputTokens , modelId , rateLimitSettings );
318+ return Objects .hash (uri , similarity , dimensions , maxInputTokens , modelId , rateLimitSettings , apiVersion );
243319 }
244320}
0 commit comments