4343import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
4444import static org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
4545import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalMap ;
46+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
4647import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredMap ;
4748import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredString ;
4849import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
5253import static org .elasticsearch .xpack .inference .services .ServiceUtils .validateMapStringValues ;
5354
5455public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings , CustomRateLimitServiceSettings {
56+
5557 public static final String NAME = "custom_service_settings" ;
5658 public static final String URL = "url" ;
59+ public static final String BATCH_SIZE = "batch_size" ; // key for the optional positive-integer batch size: read in fromMap, written by toXContentFragmentOfExposedFields
5760 public static final String HEADERS = "headers" ;
5861 public static final String REQUEST = "request" ;
5962 public static final String RESPONSE = "response" ; // also the last path segment of RESPONSE_SCOPE
6063 public static final String JSON_PARSER = "json_parser" ;
6164
6265 private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings (10_000 ); // used when the constructor is given null rate limit settings
6366 private static final String RESPONSE_SCOPE = String .join ("." , ModelConfigurations .SERVICE_SETTINGS , RESPONSE ); // ModelConfigurations.SERVICE_SETTINGS and RESPONSE joined with "."
67+ private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10 ; // fallback batch size: applied when the constructor receives null, and when reading a stream whose transport version does not carry the field
6468
6569 public static CustomServiceSettings fromMap (
6670 Map <String , Object > map ,
@@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
106110 context
107111 );
108112
113+ var batchSize = extractOptionalPositiveInteger (map , BATCH_SIZE , ModelConfigurations .SERVICE_SETTINGS , validationException );
114+
109115 if (responseParserMap == null || jsonParserMap == null ) {
110116 throw validationException ;
111117 }
@@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
124130 queryParams ,
125131 requestContentString ,
126132 responseJsonParser ,
127- rateLimitSettings
133+ rateLimitSettings ,
134+ batchSize
128135 );
129136 }
130137
@@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
142149 null ,
143150 DenseVectorFieldMapper .ElementType .FLOAT
144151 );
145-
146152 // This refers to settings that are not related to the text embedding task type (all the settings should be null)
147153 public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings (null , null , null , null );
148154
@@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196202 private final String requestContentString ;
197203 private final CustomResponseParser responseJsonParser ;
198204 private final RateLimitSettings rateLimitSettings ;
205+ private final int batchSize ;
199206
200207 public CustomServiceSettings (
201208 TextEmbeddingSettings textEmbeddingSettings ,
@@ -205,6 +212,19 @@ public CustomServiceSettings(
205212 String requestContentString ,
206213 CustomResponseParser responseJsonParser ,
207214 @ Nullable RateLimitSettings rateLimitSettings
215+ ) {
216+ this (textEmbeddingSettings , url , headers , queryParameters , requestContentString , responseJsonParser , rateLimitSettings , null );
217+ }
218+
219+ public CustomServiceSettings (
220+ TextEmbeddingSettings textEmbeddingSettings ,
221+ String url ,
222+ @ Nullable Map <String , String > headers ,
223+ @ Nullable QueryParameters queryParameters ,
224+ String requestContentString ,
225+ CustomResponseParser responseJsonParser ,
226+ @ Nullable RateLimitSettings rateLimitSettings ,
227+ @ Nullable Integer batchSize
208228 ) {
209229 this .textEmbeddingSettings = Objects .requireNonNull (textEmbeddingSettings );
210230 this .url = Objects .requireNonNull (url );
@@ -213,6 +233,7 @@ public CustomServiceSettings(
213233 this .requestContentString = Objects .requireNonNull (requestContentString );
214234 this .responseJsonParser = Objects .requireNonNull (responseJsonParser );
215235 this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
236+ this .batchSize = Objects .requireNonNullElse (batchSize , DEFAULT_EMBEDDING_BATCH_SIZE );
216237 }
217238
218239 public CustomServiceSettings (StreamInput in ) throws IOException {
@@ -223,12 +244,20 @@ public CustomServiceSettings(StreamInput in) throws IOException {
223244 requestContentString = in .readString ();
224245 responseJsonParser = in .readNamedWriteable (CustomResponseParser .class );
225246 rateLimitSettings = new RateLimitSettings (in );
247+
226248 if (in .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING )
227249 && in .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 ) == false ) {
228250 // Read the error parsing fields for backwards compatibility
229251 in .readString ();
230252 in .readString ();
231253 }
254+
255+ if (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE )
256+ || in .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
257+ batchSize = in .readVInt ();
258+ } else {
259+ batchSize = DEFAULT_EMBEDDING_BATCH_SIZE ;
260+ }
232261 }
233262
234263 @ Override
@@ -276,6 +305,10 @@ public CustomResponseParser getResponseJsonParser() {
276305 return responseJsonParser ;
277306 }
278307
308+ public int getBatchSize () { // returns the batch size held by these settings; always set, because construction falls back to DEFAULT_EMBEDDING_BATCH_SIZE
309+ return batchSize ;
310+ }
311+
279312 @ Override
280313 public RateLimitSettings rateLimitSettings () {
281314 return rateLimitSettings ;
@@ -321,6 +354,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
321354
322355 rateLimitSettings .toXContent (builder , params );
323356
357+ builder .field (BATCH_SIZE , batchSize );
358+
324359 return builder ;
325360 }
326361
@@ -343,12 +378,18 @@ public void writeTo(StreamOutput out) throws IOException {
343378 out .writeString (requestContentString );
344379 out .writeNamedWriteable (responseJsonParser );
345380 rateLimitSettings .writeTo (out );
381+
346382 if (out .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING )
347383 && out .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 ) == false ) {
348384 // Write empty strings for backwards compatibility for the error parsing fields
349385 out .writeString ("" );
350386 out .writeString ("" );
351387 }
388+
389+ if (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE )
390+ || out .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
391+ out .writeVInt (batchSize );
392+ }
352393 }
353394
354395 @ Override
@@ -362,7 +403,8 @@ public boolean equals(Object o) {
362403 && Objects .equals (queryParameters , that .queryParameters )
363404 && Objects .equals (requestContentString , that .requestContentString )
364405 && Objects .equals (responseJsonParser , that .responseJsonParser )
365- && Objects .equals (rateLimitSettings , that .rateLimitSettings );
406+ && Objects .equals (rateLimitSettings , that .rateLimitSettings )
407+ && Objects .equals (batchSize , that .batchSize );
366408 }
367409
368410 @ Override
@@ -374,7 +416,8 @@ public int hashCode() {
374416 queryParameters ,
375417 requestContentString ,
376418 responseJsonParser ,
377- rateLimitSettings
419+ rateLimitSettings ,
420+ batchSize
378421 );
379422 }
380423
0 commit comments