7
7
8
8
package org .elasticsearch .xpack .inference .services .cohere .rerank ;
9
9
10
+ import org .apache .logging .log4j .LogManager ;
11
+ import org .apache .logging .log4j .Logger ;
10
12
import org .elasticsearch .TransportVersion ;
11
13
import org .elasticsearch .TransportVersions ;
12
14
import org .elasticsearch .common .ValidationException ;
13
15
import org .elasticsearch .common .io .stream .StreamInput ;
14
16
import org .elasticsearch .common .io .stream .StreamOutput ;
17
+ import org .elasticsearch .core .Nullable ;
18
+ import org .elasticsearch .inference .ModelConfigurations ;
15
19
import org .elasticsearch .inference .ServiceSettings ;
20
+ import org .elasticsearch .inference .SimilarityMeasure ;
16
21
import org .elasticsearch .xcontent .XContentBuilder ;
17
22
import org .elasticsearch .xpack .inference .services .ConfigurationParseContext ;
18
- import org .elasticsearch .xpack .inference .services .cohere .CohereServiceSettings ;
23
+ import org .elasticsearch .xpack .inference .services .cohere .CohereRateLimitServiceSettings ;
24
+ import org .elasticsearch .xpack .inference .services .cohere .CohereService ;
19
25
import org .elasticsearch .xpack .inference .services .settings .FilteredXContentObject ;
26
+ import org .elasticsearch .xpack .inference .services .settings .RateLimitSettings ;
20
27
21
28
import java .io .IOException ;
29
+ import java .net .URI ;
22
30
import java .util .Map ;
23
31
import java .util .Objects ;
24
32
25
- public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings {
33
+ import static org .elasticsearch .xpack .inference .services .ServiceFields .DIMENSIONS ;
34
+ import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
35
+ import static org .elasticsearch .xpack .inference .services .ServiceFields .MODEL_ID ;
36
+ import static org .elasticsearch .xpack .inference .services .ServiceFields .URL ;
37
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .convertToUri ;
38
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .createOptionalUri ;
39
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalString ;
40
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
41
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .removeAsType ;
42
+ import static org .elasticsearch .xpack .inference .services .cohere .CohereServiceSettings .DEFAULT_RATE_LIMIT_SETTINGS ;
43
+
44
+ public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings , CohereRateLimitServiceSettings {
26
45
public static final String NAME = "cohere_rerank_service_settings" ;
27
46
28
- public static CohereRerankServiceSettings fromMap (Map <String , Object > map , ConfigurationParseContext parseContext ) {
47
+ private static final Logger logger = LogManager .getLogger (CohereRerankServiceSettings .class );
48
+
49
+ public static CohereRerankServiceSettings fromMap (Map <String , Object > map , ConfigurationParseContext context ) {
29
50
ValidationException validationException = new ValidationException ();
30
- var commonServiceSettings = CohereServiceSettings .fromMap (map , parseContext );
51
+
52
+ String url = extractOptionalString (map , URL , ModelConfigurations .SERVICE_SETTINGS , validationException );
53
+
54
+ // We need to extract/remove those fields to avoid unknown service settings errors
55
+ extractSimilarity (map , ModelConfigurations .SERVICE_SETTINGS , validationException );
56
+ removeAsType (map , DIMENSIONS , Integer .class );
57
+ removeAsType (map , MAX_INPUT_TOKENS , Integer .class );
58
+
59
+ URI uri = convertToUri (url , URL , ModelConfigurations .SERVICE_SETTINGS , validationException );
60
+ String modelId = extractOptionalString (map , MODEL_ID , ModelConfigurations .SERVICE_SETTINGS , validationException );
61
+ RateLimitSettings rateLimitSettings = RateLimitSettings .of (
62
+ map ,
63
+ DEFAULT_RATE_LIMIT_SETTINGS ,
64
+ validationException ,
65
+ CohereService .NAME ,
66
+ context
67
+ );
31
68
32
69
if (validationException .validationErrors ().isEmpty () == false ) {
33
70
throw validationException ;
34
71
}
35
72
36
- return new CohereRerankServiceSettings (commonServiceSettings );
73
+ return new CohereRerankServiceSettings (uri , modelId , rateLimitSettings );
37
74
}
38
75
39
- private final CohereServiceSettings commonSettings ;
76
+ private final URI uri ;
77
+
78
+ private final String modelId ;
79
+
80
+ private final RateLimitSettings rateLimitSettings ;
81
+
82
+ public CohereRerankServiceSettings (@ Nullable URI uri , @ Nullable String modelId , @ Nullable RateLimitSettings rateLimitSettings ) {
83
+ this .uri = uri ;
84
+ this .modelId = modelId ;
85
+ this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
86
+ }
40
87
41
- public CohereRerankServiceSettings (CohereServiceSettings commonSettings ) {
42
- this . commonSettings = commonSettings ;
88
+ public CohereRerankServiceSettings (@ Nullable String url , @ Nullable String modelId , @ Nullable RateLimitSettings rateLimitSettings ) {
89
+ this ( createOptionalUri ( url ), modelId , rateLimitSettings ) ;
43
90
}
44
91
45
92
public CohereRerankServiceSettings (StreamInput in ) throws IOException {
46
- commonSettings = new CohereServiceSettings (in );
93
+ this .uri = createOptionalUri (in .readOptionalString ());
94
+
95
+ if (in .getTransportVersion ().before (TransportVersions .ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED )) {
96
+ // An older node sends these fields, so we need to skip them to progress through the serialized data
97
+ in .readOptionalEnum (SimilarityMeasure .class );
98
+ in .readOptionalVInt ();
99
+ in .readOptionalVInt ();
100
+ }
101
+
102
+ this .modelId = in .readOptionalString ();
103
+
104
+ if (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED )) {
105
+ this .rateLimitSettings = new RateLimitSettings (in );
106
+ } else {
107
+ this .rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS ;
108
+ }
109
+ }
110
+
111
+ public URI uri () {
112
+ return uri ;
113
+ }
114
+
115
+ public String modelId () {
116
+ return modelId ;
117
+ }
118
+
119
+ @ Override
120
+ public RateLimitSettings rateLimitSettings () {
121
+ return rateLimitSettings ;
47
122
}
48
123
49
124
@ Override
@@ -55,15 +130,23 @@ public String getWriteableName() {
55
130
public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
56
131
builder .startObject ();
57
132
58
- commonSettings . toXContentFragment (builder , params );
133
+ toXContentFragmentOfExposedFields (builder , params );
59
134
60
135
builder .endObject ();
61
136
return builder ;
62
137
}
63
138
64
139
@ Override
65
140
protected XContentBuilder toXContentFragmentOfExposedFields (XContentBuilder builder , Params params ) throws IOException {
66
- commonSettings .toXContentFragmentOfExposedFields (builder , params );
141
+ if (uri != null ) {
142
+ builder .field (URL , uri .toString ());
143
+ }
144
+
145
+ if (modelId != null ) {
146
+ builder .field (MODEL_ID , modelId );
147
+ }
148
+
149
+ rateLimitSettings .toXContent (builder , params );
67
150
68
151
return builder ;
69
152
}
@@ -75,23 +158,36 @@ public TransportVersion getMinimalSupportedVersion() {
75
158
76
159
@ Override
77
160
public void writeTo (StreamOutput out ) throws IOException {
78
- commonSettings .writeTo (out );
161
+ var uriToWrite = uri != null ? uri .toString () : null ;
162
+ out .writeOptionalString (uriToWrite );
163
+
164
+ if (out .getTransportVersion ().before (TransportVersions .ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED )) {
165
+ // An old node expects this data to be present, so we need to send at least the booleans
166
+ // indicating that the fields are not set
167
+ out .writeOptionalEnum (null );
168
+ out .writeOptionalVInt (null );
169
+ out .writeOptionalVInt (null );
170
+ }
171
+
172
+ out .writeOptionalString (modelId );
173
+
174
+ if (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED )) {
175
+ rateLimitSettings .writeTo (out );
176
+ }
79
177
}
80
178
81
179
@ Override
82
- public boolean equals (Object o ) {
83
- if (this == o ) return true ;
84
- if (o == null || getClass () != o .getClass ()) return false ;
85
- CohereRerankServiceSettings that = (CohereRerankServiceSettings ) o ;
86
- return Objects .equals (commonSettings , that .commonSettings );
180
+ public boolean equals (Object object ) {
181
+ if (this == object ) return true ;
182
+ if (object == null || getClass () != object .getClass ()) return false ;
183
+ CohereRerankServiceSettings that = (CohereRerankServiceSettings ) object ;
184
+ return Objects .equals (uri , that .uri )
185
+ && Objects .equals (modelId , that .modelId )
186
+ && Objects .equals (rateLimitSettings , that .rateLimitSettings );
87
187
}
88
188
89
189
@ Override
90
190
public int hashCode () {
91
- return Objects .hash (commonSettings );
92
- }
93
-
94
- public CohereServiceSettings getCommonSettings () {
95
- return commonSettings ;
191
+ return Objects .hash (uri , modelId , rateLimitSettings );
96
192
}
97
193
}
0 commit comments