4343import  static  org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
4444import  static  org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
4545import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalMap ;
46+ import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
4647import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredMap ;
4748import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredString ;
4849import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
5253import  static  org .elasticsearch .xpack .inference .services .ServiceUtils .validateMapStringValues ;
5354
5455public  class  CustomServiceSettings  extends  FilteredXContentObject  implements  ServiceSettings , CustomRateLimitServiceSettings  {
56+ 
5557    public  static  final  String  NAME  = "custom_service_settings" ;
5658    public  static  final  String  URL  = "url" ;
59+     public  static  final  String  BATCH_SIZE  = "batch_size" ;
5760    public  static  final  String  HEADERS  = "headers" ;
5861    public  static  final  String  REQUEST  = "request" ;
5962    public  static  final  String  RESPONSE  = "response" ;
6063    public  static  final  String  JSON_PARSER  = "json_parser" ;
6164
6265    private  static  final  RateLimitSettings  DEFAULT_RATE_LIMIT_SETTINGS  = new  RateLimitSettings (10_000 );
6366    private  static  final  String  RESPONSE_SCOPE  = String .join ("." , ModelConfigurations .SERVICE_SETTINGS , RESPONSE );
        // Fallback batch size: applied by the public constructor when its batchSize argument is null, and by the
        // StreamInput constructor when the sender's transport version does not include the batch size field.
67+     private  static  final  int  DEFAULT_EMBEDDING_BATCH_SIZE  = 10 ;
6468
6569    public  static  CustomServiceSettings  fromMap (
6670        Map <String , Object > map ,
@@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
106110            context 
107111        );
108112
113+         var  batchSize  = extractOptionalPositiveInteger (map , BATCH_SIZE , ModelConfigurations .SERVICE_SETTINGS , validationException );
114+ 
109115        if  (responseParserMap  == null  || jsonParserMap  == null ) {
110116            throw  validationException ;
111117        }
@@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
124130            queryParams ,
125131            requestContentString ,
126132            responseJsonParser ,
127-             rateLimitSettings 
133+             rateLimitSettings ,
134+             batchSize 
128135        );
129136    }
130137
@@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
142149            null ,
143150            DenseVectorFieldMapper .ElementType .FLOAT 
144151        );
145- 
146152        // This refers to settings that are not related to the text embedding task type (all the settings should be null) 
147153        public  static  final  TextEmbeddingSettings  NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS  = new  TextEmbeddingSettings (null , null , null , null );
148154
@@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196202    private  final  String  requestContentString ;
197203    private  final  CustomResponseParser  responseJsonParser ;
198204    private  final  RateLimitSettings  rateLimitSettings ;
205+     private  final  int  batchSize ;
199206
200207    public  CustomServiceSettings (
201208        TextEmbeddingSettings  textEmbeddingSettings ,
@@ -205,6 +212,19 @@ public CustomServiceSettings(
205212        String  requestContentString ,
206213        CustomResponseParser  responseJsonParser ,
207214        @ Nullable  RateLimitSettings  rateLimitSettings 
215+     ) {
216+         this (textEmbeddingSettings , url , headers , queryParameters , requestContentString , responseJsonParser , rateLimitSettings , null );
217+     }
218+ 
219+     public  CustomServiceSettings (
220+         TextEmbeddingSettings  textEmbeddingSettings ,
221+         String  url ,
222+         @ Nullable  Map <String , String > headers ,
223+         @ Nullable  QueryParameters  queryParameters ,
224+         String  requestContentString ,
225+         CustomResponseParser  responseJsonParser ,
226+         @ Nullable  RateLimitSettings  rateLimitSettings ,
227+         @ Nullable  Integer  batchSize 
208228    ) {
209229        this .textEmbeddingSettings  = Objects .requireNonNull (textEmbeddingSettings );
210230        this .url  = Objects .requireNonNull (url );
@@ -213,6 +233,7 @@ public CustomServiceSettings(
213233        this .requestContentString  = Objects .requireNonNull (requestContentString );
214234        this .responseJsonParser  = Objects .requireNonNull (responseJsonParser );
215235        this .rateLimitSettings  = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
236+         this .batchSize  = Objects .requireNonNullElse (batchSize , DEFAULT_EMBEDDING_BATCH_SIZE );
216237    }
217238
218239    public  CustomServiceSettings (StreamInput  in ) throws  IOException  {
@@ -223,12 +244,20 @@ public CustomServiceSettings(StreamInput in) throws IOException {
223244        requestContentString  = in .readString ();
224245        responseJsonParser  = in .readNamedWriteable (CustomResponseParser .class );
225246        rateLimitSettings  = new  RateLimitSettings (in );
247+ 
226248        if  (in .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING )
227249            && in .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 ) == false ) {
228250            // Read the error parsing fields for backwards compatibility 
229251            in .readString ();
230252            in .readString ();
231253        }
254+ 
255+         if  (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE )
256+             || in .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
257+             batchSize  = in .readVInt ();
258+         } else  {
259+             batchSize  = DEFAULT_EMBEDDING_BATCH_SIZE ;
260+         }
232261    }
233262
234263    @ Override 
@@ -276,6 +305,10 @@ public CustomResponseParser getResponseJsonParser() {
276305        return  responseJsonParser ;
277306    }
278307
308+     public  int  getBatchSize () {
             // Returns the stored batch size. The field is final and is assigned in the constructors: either the
             // caller-supplied value, the value read from the stream, or DEFAULT_EMBEDDING_BATCH_SIZE.
309+         return  batchSize ;
310+     }
311+ 
279312    @ Override 
280313    public  RateLimitSettings  rateLimitSettings () {
281314        return  rateLimitSettings ;
@@ -321,6 +354,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
321354
322355        rateLimitSettings .toXContent (builder , params );
323356
357+         builder .field (BATCH_SIZE , batchSize );
358+ 
324359        return  builder ;
325360    }
326361
@@ -343,12 +378,18 @@ public void writeTo(StreamOutput out) throws IOException {
343378        out .writeString (requestContentString );
344379        out .writeNamedWriteable (responseJsonParser );
345380        rateLimitSettings .writeTo (out );
381+ 
346382        if  (out .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING )
347383            && out .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 ) == false ) {
348384            // Write empty strings for backwards compatibility for the error parsing fields 
349385            out .writeString ("" );
350386            out .writeString ("" );
351387        }
388+ 
389+         if  (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE )
390+             || out .getTransportVersion ().isPatchFrom (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
391+             out .writeVInt (batchSize );
392+         }
352393    }
353394
354395    @ Override 
@@ -362,7 +403,8 @@ public boolean equals(Object o) {
362403            && Objects .equals (queryParameters , that .queryParameters )
363404            && Objects .equals (requestContentString , that .requestContentString )
364405            && Objects .equals (responseJsonParser , that .responseJsonParser )
365-             && Objects .equals (rateLimitSettings , that .rateLimitSettings );
406+             && Objects .equals (rateLimitSettings , that .rateLimitSettings )
407+             && Objects .equals (batchSize , that .batchSize );
366408    }
367409
368410    @ Override 
@@ -374,7 +416,8 @@ public int hashCode() {
374416            queryParameters ,
375417            requestContentString ,
376418            responseJsonParser ,
377-             rateLimitSettings 
419+             rateLimitSettings ,
420+             batchSize 
378421        );
379422    }
380423
0 commit comments