2727import org .elasticsearch .cluster .metadata .Metadata ;
2828import org .elasticsearch .cluster .service .ClusterService ;
2929import org .elasticsearch .common .Strings ;
30+ import org .elasticsearch .common .settings .ClusterSettings ;
3031import org .elasticsearch .common .settings .Settings ;
32+ import org .elasticsearch .common .unit .ByteSizeValue ;
3133import org .elasticsearch .common .xcontent .XContentHelper ;
3234import org .elasticsearch .common .xcontent .support .XContentMapValues ;
3335import org .elasticsearch .index .IndexVersion ;
6567import java .util .List ;
6668import java .util .Map ;
6769import java .util .Optional ;
70+ import java .util .Set ;
6871import java .util .concurrent .CountDownLatch ;
6972import java .util .concurrent .TimeUnit ;
7073
7174import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .assertToXContentEquivalent ;
7275import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .awaitLatch ;
73- import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .DEFAULT_BATCH_SIZE ;
76+ import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .INDICES_INFERENCE_BATCH_SIZE ;
7477import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .getIndexRequestOrNull ;
7578import static org .elasticsearch .xpack .inference .mapper .SemanticTextField .getChunksFieldName ;
7679import static org .elasticsearch .xpack .inference .mapper .SemanticTextField .getOriginalTextFieldName ;
@@ -115,7 +118,7 @@ public void tearDownThreadPool() throws Exception {
115118
116119 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
117120 public void testFilterNoop () throws Exception {
118- ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), DEFAULT_BATCH_SIZE , useLegacyFormat , true );
121+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), useLegacyFormat , true );
119122 CountDownLatch chainExecuted = new CountDownLatch (1 );
120123 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
121124 try {
@@ -141,7 +144,7 @@ public void testFilterNoop() throws Exception {
141144 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
142145 public void testLicenseInvalidForInference () throws InterruptedException {
143146 StaticModel model = StaticModel .createRandomInstance ();
144- ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), DEFAULT_BATCH_SIZE , useLegacyFormat , false );
147+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), useLegacyFormat , false );
145148 CountDownLatch chainExecuted = new CountDownLatch (1 );
146149 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
147150 try {
@@ -182,7 +185,6 @@ public void testInferenceNotFound() throws Exception {
182185 ShardBulkInferenceActionFilter filter = createFilter (
183186 threadPool ,
184187 Map .of (model .getInferenceEntityId (), model ),
185- randomIntBetween (1 , 10 ),
186188 useLegacyFormat ,
187189 true
188190 );
@@ -229,7 +231,6 @@ public void testItemFailures() throws Exception {
229231 ShardBulkInferenceActionFilter filter = createFilter (
230232 threadPool ,
231233 Map .of (model .getInferenceEntityId (), model ),
232- randomIntBetween (1 , 10 ),
233234 useLegacyFormat ,
234235 true
235236 );
@@ -300,7 +301,6 @@ public void testExplicitNull() throws Exception {
300301 ShardBulkInferenceActionFilter filter = createFilter (
301302 threadPool ,
302303 Map .of (model .getInferenceEntityId (), model ),
303- randomIntBetween (1 , 10 ),
304304 useLegacyFormat ,
305305 true
306306 );
@@ -371,7 +371,6 @@ public void testHandleEmptyInput() throws Exception {
371371 ShardBulkInferenceActionFilter filter = createFilter (
372372 threadPool ,
373373 Map .of (model .getInferenceEntityId (), model ),
374- randomIntBetween (1 , 10 ),
375374 useLegacyFormat ,
376375 true
377376 );
@@ -444,13 +443,7 @@ public void testManyRandomDocs() throws Exception {
444443 modifiedRequests [id ] = res [1 ];
445444 }
446445
447- ShardBulkInferenceActionFilter filter = createFilter (
448- threadPool ,
449- inferenceModelMap ,
450- randomIntBetween (10 , 30 ),
451- useLegacyFormat ,
452- true
453- );
446+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , inferenceModelMap , useLegacyFormat , true );
454447 CountDownLatch chainExecuted = new CountDownLatch (1 );
455448 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
456449 try {
@@ -484,7 +477,6 @@ public void testManyRandomDocs() throws Exception {
484477 private static ShardBulkInferenceActionFilter createFilter (
485478 ThreadPool threadPool ,
486479 Map <String , StaticModel > modelMap ,
487- int batchSize ,
488480 boolean useLegacyFormat ,
489481 boolean isLicenseValidForInference
490482 ) {
@@ -551,26 +543,28 @@ private static ShardBulkInferenceActionFilter createFilter(
551543 createClusterService (useLegacyFormat ),
552544 inferenceServiceRegistry ,
553545 modelRegistry ,
554- licenseState ,
555- batchSize
546+ licenseState
556547 );
557548 }
558549
559550 private static ClusterService createClusterService (boolean useLegacyFormat ) {
560551 IndexMetadata indexMetadata = mock (IndexMetadata .class );
561- var settings = Settings .builder ()
552+ var indexSettings = Settings .builder ()
562553 .put (IndexMetadata .SETTING_INDEX_VERSION_CREATED .getKey (), IndexVersion .current ())
563554 .put (InferenceMetadataFieldsMapper .USE_LEGACY_SEMANTIC_TEXT_FORMAT .getKey (), useLegacyFormat )
564555 .build ();
565- when (indexMetadata .getSettings ()).thenReturn (settings );
556+ when (indexMetadata .getSettings ()).thenReturn (indexSettings );
566557
567558 Metadata metadata = mock (Metadata .class );
568559 when (metadata .index (any (String .class ))).thenReturn (indexMetadata );
569560
570561 ClusterState clusterState = ClusterState .builder (new ClusterName ("test" )).metadata (metadata ).build ();
571562 ClusterService clusterService = mock (ClusterService .class );
572563 when (clusterService .state ()).thenReturn (clusterState );
573-
564+ long batchSizeInBytes = randomLongBetween (0 , ByteSizeValue .ofKb (1 ).getBytes ());
565+ Settings settings = Settings .builder ().put (INDICES_INFERENCE_BATCH_SIZE .getKey (), ByteSizeValue .ofBytes (batchSizeInBytes )).build ();
566+ when (clusterService .getSettings ()).thenReturn (settings );
567+ when (clusterService .getClusterSettings ()).thenReturn (new ClusterSettings (settings , Set .of (INDICES_INFERENCE_BATCH_SIZE )));
574568 return clusterService ;
575569 }
576570
@@ -581,7 +575,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
581575 ) throws IOException {
582576 Map <String , Object > docMap = new LinkedHashMap <>();
583577 Map <String , Object > expectedDocMap = new LinkedHashMap <>();
584- XContentType requestContentType = randomFrom (XContentType .values ());
578+ // force JSON to avoid double/float conversions
579+ XContentType requestContentType = XContentType .JSON ;
585580
586581 Map <String , Object > inferenceMetadataFields = new HashMap <>();
587582 for (var entry : fieldInferenceMap .values ()) {
0 commit comments