2828import org .elasticsearch .cluster .metadata .ProjectMetadata ;
2929import org .elasticsearch .cluster .service .ClusterService ;
3030import org .elasticsearch .common .Strings ;
31+ import org .elasticsearch .common .settings .ClusterSettings ;
3132import org .elasticsearch .common .settings .Settings ;
33+ import org .elasticsearch .common .unit .ByteSizeValue ;
3234import org .elasticsearch .common .xcontent .XContentHelper ;
3335import org .elasticsearch .common .xcontent .support .XContentMapValues ;
3436import org .elasticsearch .index .IndexVersion ;
6668import java .util .List ;
6769import java .util .Map ;
6870import java .util .Optional ;
71+ import java .util .Set ;
6972import java .util .concurrent .CountDownLatch ;
7073import java .util .concurrent .TimeUnit ;
7174
7275import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .assertToXContentEquivalent ;
7376import static org .elasticsearch .test .hamcrest .ElasticsearchAssertions .awaitLatch ;
74- import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .DEFAULT_BATCH_SIZE ;
77+ import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .INDICES_INFERENCE_BATCH_SIZE ;
7578import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .getIndexRequestOrNull ;
7679import static org .elasticsearch .xpack .inference .mapper .SemanticTextField .getChunksFieldName ;
7780import static org .elasticsearch .xpack .inference .mapper .SemanticTextField .getOriginalTextFieldName ;
@@ -118,7 +121,7 @@ public void tearDownThreadPool() throws Exception {
118121
119122 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
120123 public void testFilterNoop () throws Exception {
121- ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), DEFAULT_BATCH_SIZE , useLegacyFormat , true );
124+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), useLegacyFormat , true );
122125 CountDownLatch chainExecuted = new CountDownLatch (1 );
123126 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
124127 try {
@@ -144,7 +147,7 @@ public void testFilterNoop() throws Exception {
144147 @ SuppressWarnings ({ "unchecked" , "rawtypes" })
145148 public void testLicenseInvalidForInference () throws InterruptedException {
146149 StaticModel model = StaticModel .createRandomInstance ();
147- ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), DEFAULT_BATCH_SIZE , useLegacyFormat , false );
150+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , Map .of (), useLegacyFormat , false );
148151 CountDownLatch chainExecuted = new CountDownLatch (1 );
149152 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
150153 try {
@@ -185,7 +188,6 @@ public void testInferenceNotFound() throws Exception {
185188 ShardBulkInferenceActionFilter filter = createFilter (
186189 threadPool ,
187190 Map .of (model .getInferenceEntityId (), model ),
188- randomIntBetween (1 , 10 ),
189191 useLegacyFormat ,
190192 true
191193 );
@@ -232,7 +234,6 @@ public void testItemFailures() throws Exception {
232234 ShardBulkInferenceActionFilter filter = createFilter (
233235 threadPool ,
234236 Map .of (model .getInferenceEntityId (), model ),
235- randomIntBetween (1 , 10 ),
236237 useLegacyFormat ,
237238 true
238239 );
@@ -303,7 +304,6 @@ public void testExplicitNull() throws Exception {
303304 ShardBulkInferenceActionFilter filter = createFilter (
304305 threadPool ,
305306 Map .of (model .getInferenceEntityId (), model ),
306- randomIntBetween (1 , 10 ),
307307 useLegacyFormat ,
308308 true
309309 );
@@ -374,7 +374,6 @@ public void testHandleEmptyInput() throws Exception {
374374 ShardBulkInferenceActionFilter filter = createFilter (
375375 threadPool ,
376376 Map .of (model .getInferenceEntityId (), model ),
377- randomIntBetween (1 , 10 ),
378377 useLegacyFormat ,
379378 true
380379 );
@@ -447,13 +446,7 @@ public void testManyRandomDocs() throws Exception {
447446 modifiedRequests [id ] = res [1 ];
448447 }
449448
450- ShardBulkInferenceActionFilter filter = createFilter (
451- threadPool ,
452- inferenceModelMap ,
453- randomIntBetween (10 , 30 ),
454- useLegacyFormat ,
455- true
456- );
449+ ShardBulkInferenceActionFilter filter = createFilter (threadPool , inferenceModelMap , useLegacyFormat , true );
457450 CountDownLatch chainExecuted = new CountDownLatch (1 );
458451 ActionFilterChain actionFilterChain = (task , action , request , listener ) -> {
459452 try {
@@ -487,7 +480,6 @@ public void testManyRandomDocs() throws Exception {
487480 private static ShardBulkInferenceActionFilter createFilter (
488481 ThreadPool threadPool ,
489482 Map <String , StaticModel > modelMap ,
490- int batchSize ,
491483 boolean useLegacyFormat ,
492484 boolean isLicenseValidForInference
493485 ) {
@@ -554,18 +546,17 @@ private static ShardBulkInferenceActionFilter createFilter(
554546 createClusterService (useLegacyFormat ),
555547 inferenceServiceRegistry ,
556548 modelRegistry ,
557- licenseState ,
558- batchSize
549+ licenseState
559550 );
560551 }
561552
562553 private static ClusterService createClusterService (boolean useLegacyFormat ) {
563554 IndexMetadata indexMetadata = mock (IndexMetadata .class );
564- var settings = Settings .builder ()
555+ var indexSettings = Settings .builder ()
565556 .put (IndexMetadata .SETTING_INDEX_VERSION_CREATED .getKey (), IndexVersion .current ())
566557 .put (InferenceMetadataFieldsMapper .USE_LEGACY_SEMANTIC_TEXT_FORMAT .getKey (), useLegacyFormat )
567558 .build ();
568- when (indexMetadata .getSettings ()).thenReturn (settings );
559+ when (indexMetadata .getSettings ()).thenReturn (indexSettings );
569560
570561 ProjectMetadata project = spy (ProjectMetadata .builder (Metadata .DEFAULT_PROJECT_ID ).build ());
571562 when (project .index (anyString ())).thenReturn (indexMetadata );
@@ -576,7 +567,10 @@ private static ClusterService createClusterService(boolean useLegacyFormat) {
576567 ClusterState clusterState = ClusterState .builder (new ClusterName ("test" )).metadata (metadata ).build ();
577568 ClusterService clusterService = mock (ClusterService .class );
578569 when (clusterService .state ()).thenReturn (clusterState );
579-
570+ long batchSizeInBytes = randomLongBetween (0 , ByteSizeValue .ofKb (1 ).getBytes ());
571+ Settings settings = Settings .builder ().put (INDICES_INFERENCE_BATCH_SIZE .getKey (), ByteSizeValue .ofBytes (batchSizeInBytes )).build ();
572+ when (clusterService .getSettings ()).thenReturn (settings );
573+ when (clusterService .getClusterSettings ()).thenReturn (new ClusterSettings (settings , Set .of (INDICES_INFERENCE_BATCH_SIZE )));
580574 return clusterService ;
581575 }
582576
@@ -587,7 +581,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
587581 ) throws IOException {
588582 Map <String , Object > docMap = new LinkedHashMap <>();
589583 Map <String , Object > expectedDocMap = new LinkedHashMap <>();
590- XContentType requestContentType = randomFrom (XContentType .values ());
584+ // force JSON to avoid double/float conversions
585+ XContentType requestContentType = XContentType .JSON ;
591586
592587 Map <String , Object > inferenceMetadataFields = new HashMap <>();
593588 for (var entry : fieldInferenceMap .values ()) {
0 commit comments