4343import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
4444import static org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
4545import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalMap ;
46+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
4647import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredMap ;
4748import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredString ;
4849import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
5253import static org .elasticsearch .xpack .inference .services .ServiceUtils .validateMapStringValues ;
5354
5455public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings , CustomRateLimitServiceSettings {
56+
5557 public static final String NAME = "custom_service_settings" ;
5658 public static final String URL = "url" ;
59+ public static final String BATCH_SIZE = "batch_size" ;
5760 public static final String HEADERS = "headers" ;
5861 public static final String REQUEST = "request" ;
5962 public static final String RESPONSE = "response" ;
6063 public static final String JSON_PARSER = "json_parser" ;
6164
6265 private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings (10_000 );
6366 private static final String RESPONSE_SCOPE = String .join ("." , ModelConfigurations .SERVICE_SETTINGS , RESPONSE );
67+ private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10 ;
6468
6569 public static CustomServiceSettings fromMap (
6670 Map <String , Object > map ,
@@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
106110 context
107111 );
108112
113+ var batchSize = extractOptionalPositiveInteger (map , BATCH_SIZE , ModelConfigurations .SERVICE_SETTINGS , validationException );
114+
109115 if (responseParserMap == null || jsonParserMap == null ) {
110116 throw validationException ;
111117 }
@@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
124130 queryParams ,
125131 requestContentString ,
126132 responseJsonParser ,
127- rateLimitSettings
133+ rateLimitSettings ,
134+ batchSize
128135 );
129136 }
130137
@@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
142149 null ,
143150 DenseVectorFieldMapper .ElementType .FLOAT
144151 );
145-
146152 // This refers to settings that are not related to the text embedding task type (all the settings should be null)
147153 public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings (null , null , null , null );
148154
@@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196202 private final String requestContentString ;
197203 private final CustomResponseParser responseJsonParser ;
198204 private final RateLimitSettings rateLimitSettings ;
205+ private final int batchSize ;
199206
200207 public CustomServiceSettings (
201208 TextEmbeddingSettings textEmbeddingSettings ,
@@ -205,6 +212,19 @@ public CustomServiceSettings(
205212 String requestContentString ,
206213 CustomResponseParser responseJsonParser ,
207214 @ Nullable RateLimitSettings rateLimitSettings
215+ ) {
216+ this (textEmbeddingSettings , url , headers , queryParameters , requestContentString , responseJsonParser , rateLimitSettings , null );
217+ }
218+
219+ public CustomServiceSettings (
220+ TextEmbeddingSettings textEmbeddingSettings ,
221+ String url ,
222+ @ Nullable Map <String , String > headers ,
223+ @ Nullable QueryParameters queryParameters ,
224+ String requestContentString ,
225+ CustomResponseParser responseJsonParser ,
226+ @ Nullable RateLimitSettings rateLimitSettings ,
227+ @ Nullable Integer batchSize
208228 ) {
209229 this .textEmbeddingSettings = Objects .requireNonNull (textEmbeddingSettings );
210230 this .url = Objects .requireNonNull (url );
@@ -213,6 +233,7 @@ public CustomServiceSettings(
213233 this .requestContentString = Objects .requireNonNull (requestContentString );
214234 this .responseJsonParser = Objects .requireNonNull (responseJsonParser );
215235 this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
236+ this .batchSize = Objects .requireNonNullElse (batchSize , DEFAULT_EMBEDDING_BATCH_SIZE );
216237 }
217238
218239 public CustomServiceSettings (StreamInput in ) throws IOException {
@@ -223,11 +244,18 @@ public CustomServiceSettings(StreamInput in) throws IOException {
223244 requestContentString = in .readString ();
224245 responseJsonParser = in .readNamedWriteable (CustomResponseParser .class );
225246 rateLimitSettings = new RateLimitSettings (in );
247+
226248 if (in .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 )) {
227249 // Read the error parsing fields for backwards compatibility
228250 in .readString ();
229251 in .readString ();
230252 }
253+
254+ if (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
255+ batchSize = in .readVInt ();
256+ } else {
257+ batchSize = DEFAULT_EMBEDDING_BATCH_SIZE ;
258+ }
231259 }
232260
233261 @ Override
@@ -275,6 +303,10 @@ public CustomResponseParser getResponseJsonParser() {
275303 return responseJsonParser ;
276304 }
277305
306+ public int getBatchSize () {
307+ return batchSize ;
308+ }
309+
278310 @ Override
279311 public RateLimitSettings rateLimitSettings () {
280312 return rateLimitSettings ;
@@ -320,6 +352,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
320352
321353 rateLimitSettings .toXContent (builder , params );
322354
355+ builder .field (BATCH_SIZE , batchSize );
356+
323357 return builder ;
324358 }
325359
@@ -342,11 +376,16 @@ public void writeTo(StreamOutput out) throws IOException {
342376 out .writeString (requestContentString );
343377 out .writeNamedWriteable (responseJsonParser );
344378 rateLimitSettings .writeTo (out );
379+
345380 if (out .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 )) {
346381 // Write empty strings for backwards compatibility for the error parsing fields
347382 out .writeString ("" );
348383 out .writeString ("" );
349384 }
385+
386+ if (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
387+ out .writeVInt (batchSize );
388+ }
350389 }
351390
352391 @ Override
@@ -360,7 +399,8 @@ public boolean equals(Object o) {
360399 && Objects .equals (queryParameters , that .queryParameters )
361400 && Objects .equals (requestContentString , that .requestContentString )
362401 && Objects .equals (responseJsonParser , that .responseJsonParser )
363- && Objects .equals (rateLimitSettings , that .rateLimitSettings );
402+ && Objects .equals (rateLimitSettings , that .rateLimitSettings )
403+ && Objects .equals (batchSize , that .batchSize );
364404 }
365405
366406 @ Override
@@ -372,7 +412,8 @@ public int hashCode() {
372412 queryParameters ,
373413 requestContentString ,
374414 responseJsonParser ,
375- rateLimitSettings
415+ rateLimitSettings ,
416+ batchSize
376417 );
377418 }
378419
0 commit comments