2626import org .elasticsearch .index .mapper .SourceFieldMapper ;
2727import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
2828import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapperTestUtils ;
29+ import org .elasticsearch .inference .Model ;
2930import org .elasticsearch .inference .SimilarityMeasure ;
3031import org .elasticsearch .license .LicenseSettings ;
3132import org .elasticsearch .plugins .Plugin ;
3233import org .elasticsearch .search .builder .SearchSourceBuilder ;
3334import org .elasticsearch .test .ESIntegTestCase ;
35+ import org .elasticsearch .test .InternalTestCluster ;
36+ import org .elasticsearch .xpack .inference .InferenceIndex ;
3437import org .elasticsearch .xpack .inference .LocalStateInferencePlugin ;
35- import org .elasticsearch .xpack .inference .Utils ;
3638import org .elasticsearch .xpack .inference .mock .TestDenseInferenceServiceExtension ;
3739import org .elasticsearch .xpack .inference .mock .TestSparseInferenceServiceExtension ;
3840import org .elasticsearch .xpack .inference .registry .ModelRegistry ;
4547import java .util .Locale ;
4648import java .util .Map ;
4749import java .util .Set ;
50+ import java .util .function .Function ;
4851
52+ import static org .elasticsearch .xpack .inference .Utils .storeDenseModel ;
53+ import static org .elasticsearch .xpack .inference .Utils .storeModel ;
54+ import static org .elasticsearch .xpack .inference .Utils .storeSparseModel ;
4955import static org .elasticsearch .xpack .inference .action .filter .ShardBulkInferenceActionFilter .INDICES_INFERENCE_BATCH_SIZE ;
5056import static org .elasticsearch .xpack .inference .mapper .SemanticTextFieldTests .randomSemanticTextInput ;
5157import static org .hamcrest .Matchers .containsString ;
@@ -56,6 +62,7 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
5662
5763 private final boolean useLegacyFormat ;
5864 private final boolean useSyntheticSource ;
65+ private ModelRegistry modelRegistry ;
5966
6067 public ShardBulkInferenceActionFilterIT (boolean useLegacyFormat , boolean useSyntheticSource ) {
6168 this .useLegacyFormat = useLegacyFormat ;
@@ -74,16 +81,16 @@ public static Iterable<Object[]> parameters() throws Exception {
7481
7582 @ Before
7683 public void setup () throws Exception {
77- ModelRegistry modelRegistry = internalCluster ().getCurrentMasterNodeInstance (ModelRegistry .class );
84+ modelRegistry = internalCluster ().getCurrentMasterNodeInstance (ModelRegistry .class );
7885 DenseVectorFieldMapper .ElementType elementType = randomFrom (DenseVectorFieldMapper .ElementType .values ());
7986 // dot product means that we need normalized vectors; it's not worth doing that in this test
8087 SimilarityMeasure similarity = randomValueOtherThan (
8188 SimilarityMeasure .DOT_PRODUCT ,
8289 () -> randomFrom (DenseVectorFieldMapperTestUtils .getSupportedSimilarities (elementType ))
8390 );
8491 int dimensions = DenseVectorFieldMapperTestUtils .randomCompatibleDimensions (elementType , 100 );
85- Utils . storeSparseModel (modelRegistry );
86- Utils . storeDenseModel (modelRegistry , dimensions , similarity , elementType );
92+ storeSparseModel (modelRegistry );
93+ storeDenseModel (modelRegistry , dimensions , similarity , elementType );
8794 }
8895
8996 @ Override
@@ -135,32 +142,131 @@ public void testBulkOperations() throws Exception {
135142 TestDenseInferenceServiceExtension .TestInferenceService .NAME
136143 )
137144 ).get ();
145+ assertRandomBulkOperations (INDEX_NAME , isIndexRequest -> {
146+ Map <String , Object > map = new HashMap <>();
147+ map .put ("sparse_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
148+ map .put ("dense_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
149+ return map ;
150+ });
151+ }
152+
153+ public void testItemFailures () {
154+ prepareCreate (INDEX_NAME ).setMapping (
155+ String .format (
156+ Locale .ROOT ,
157+ """
158+ {
159+ "properties": {
160+ "sparse_field": {
161+ "type": "semantic_text",
162+ "inference_id": "%s"
163+ },
164+ "dense_field": {
165+ "type": "semantic_text",
166+ "inference_id": "%s"
167+ }
168+ }
169+ }
170+ """ ,
171+ TestSparseInferenceServiceExtension .TestInferenceService .NAME ,
172+ TestDenseInferenceServiceExtension .TestInferenceService .NAME
173+ )
174+ ).get ();
175+
176+ BulkRequestBuilder bulkReqBuilder = client ().prepareBulk ();
177+ int totalBulkSize = randomIntBetween (100 , 200 ); // Use a bulk request size large enough to require batching
178+ for (int bulkSize = 0 ; bulkSize < totalBulkSize ; bulkSize ++) {
179+ String id = Integer .toString (bulkSize );
180+
181+ // Set field values that will cause errors when generating inference requests
182+ Map <String , Object > source = new HashMap <>();
183+ source .put ("sparse_field" , List .of (Map .of ("foo" , "bar" ), Map .of ("baz" , "bar" )));
184+ source .put ("dense_field" , List .of (Map .of ("foo" , "bar" ), Map .of ("baz" , "bar" )));
185+
186+ bulkReqBuilder .add (new IndexRequestBuilder (client ()).setIndex (INDEX_NAME ).setId (id ).setSource (source ));
187+ }
188+
189+ BulkResponse bulkResponse = bulkReqBuilder .get ();
190+ assertThat (bulkResponse .hasFailures (), equalTo (true ));
191+ for (BulkItemResponse bulkItemResponse : bulkResponse .getItems ()) {
192+ assertThat (bulkItemResponse .isFailed (), equalTo (true ));
193+ assertThat (bulkItemResponse .getFailureMessage (), containsString ("expected [String|Number|Boolean]" ));
194+ }
195+ }
196+
197+ public void testRestart () throws Exception {
198+ Model model1 = new TestSparseInferenceServiceExtension .TestSparseModel (
199+ "another_inference_endpoint" ,
200+ new TestSparseInferenceServiceExtension .TestServiceSettings ("sparse_model" , null , false )
201+ );
202+ storeModel (modelRegistry , model1 );
203+ prepareCreate ("index_restart" ).setMapping ("""
204+ {
205+ "properties": {
206+ "sparse_field": {
207+ "type": "semantic_text",
208+ "inference_id": "new_inference_endpoint"
209+ },
210+ "other_field": {
211+ "type": "semantic_text",
212+ "inference_id": "another_inference_endpoint"
213+ }
214+ }
215+ }
216+ """ ).get ();
217+ Model model2 = new TestSparseInferenceServiceExtension .TestSparseModel (
218+ "new_inference_endpoint" ,
219+ new TestSparseInferenceServiceExtension .TestServiceSettings ("sparse_model" , null , false )
220+ );
221+ storeModel (modelRegistry , model2 );
222+
223+ internalCluster ().fullRestart (new InternalTestCluster .RestartCallback ());
224+ ensureGreen (InferenceIndex .INDEX_NAME , "index_restart" );
138225
226+ assertRandomBulkOperations ("index_restart" , isIndexRequest -> {
227+ Map <String , Object > map = new HashMap <>();
228+ map .put ("sparse_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
229+ map .put ("other_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
230+ return map ;
231+ });
232+
233+ internalCluster ().fullRestart (new InternalTestCluster .RestartCallback ());
234+ ensureGreen (InferenceIndex .INDEX_NAME , "index_restart" );
235+
236+ assertRandomBulkOperations ("index_restart" , isIndexRequest -> {
237+ Map <String , Object > map = new HashMap <>();
238+ map .put ("sparse_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
239+ map .put ("other_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
240+ return map ;
241+ });
242+ }
243+
244+ private void assertRandomBulkOperations (String indexName , Function <Boolean , Map <String , Object >> sourceSupplier ) throws Exception {
245+ int numHits = numHits (indexName );
139246 int totalBulkReqs = randomIntBetween (2 , 100 );
140- long totalDocs = 0 ;
247+ long totalDocs = numHits ;
141248 Set <String > ids = new HashSet <>();
142- for (int bulkReqs = 0 ; bulkReqs < totalBulkReqs ; bulkReqs ++) {
249+
250+ for (int bulkReqs = numHits ; bulkReqs < totalBulkReqs ; bulkReqs ++) {
143251 BulkRequestBuilder bulkReqBuilder = client ().prepareBulk ();
144252 int totalBulkSize = randomIntBetween (1 , 100 );
145253 for (int bulkSize = 0 ; bulkSize < totalBulkSize ; bulkSize ++) {
146254 if (ids .size () > 0 && rarely (random ())) {
147255 String id = randomFrom (ids );
148256 ids .remove (id );
149- DeleteRequestBuilder request = new DeleteRequestBuilder (client (), INDEX_NAME ).setId (id );
257+ DeleteRequestBuilder request = new DeleteRequestBuilder (client (), indexName ).setId (id );
150258 bulkReqBuilder .add (request );
151259 continue ;
152260 }
153261 String id = Long .toString (totalDocs ++);
154262 boolean isIndexRequest = randomBoolean ();
155- Map <String , Object > source = new HashMap <>();
156- source .put ("sparse_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
157- source .put ("dense_field" , isIndexRequest && rarely () ? null : randomSemanticTextInput ());
263+ Map <String , Object > source = sourceSupplier .apply (isIndexRequest );
158264 if (isIndexRequest ) {
159- bulkReqBuilder .add (new IndexRequestBuilder (client ()).setIndex (INDEX_NAME ).setId (id ).setSource (source ));
265+ bulkReqBuilder .add (new IndexRequestBuilder (client ()).setIndex (indexName ).setId (id ).setSource (source ));
160266 ids .add (id );
161267 } else {
162268 boolean isUpsert = randomBoolean ();
163- UpdateRequestBuilder request = new UpdateRequestBuilder (client ()).setIndex (INDEX_NAME ).setDoc (source );
269+ UpdateRequestBuilder request = new UpdateRequestBuilder (client ()).setIndex (indexName ).setDoc (source );
164270 if (isUpsert || ids .size () == 0 ) {
165271 request .setDocAsUpsert (true );
166272 } else {
@@ -188,59 +294,17 @@ public void testBulkOperations() throws Exception {
188294 }
189295 assertFalse (bulkResponse .hasFailures ());
190296 }
297+ client ().admin ().indices ().refresh (new RefreshRequest (indexName )).get ();
298+ assertThat (numHits (indexName ), equalTo (ids .size () + numHits ));
299+ }
191300
192- client ().admin ().indices ().refresh (new RefreshRequest (INDEX_NAME )).get ();
193-
301+ private int numHits (String indexName ) throws Exception {
194302 SearchSourceBuilder sourceBuilder = new SearchSourceBuilder ().size (0 ).trackTotalHits (true );
195- SearchResponse searchResponse = client ().search (new SearchRequest (INDEX_NAME ).source (sourceBuilder )).get ();
303+ SearchResponse searchResponse = client ().search (new SearchRequest (indexName ).source (sourceBuilder )).get ();
196304 try {
197- assertThat ( searchResponse .getHits ().getTotalHits ().value , equalTo (( long ) ids . size ())) ;
305+ return ( int ) searchResponse .getHits ().getTotalHits ().value ;
198306 } finally {
199307 searchResponse .decRef ();
200308 }
201309 }
202-
203- public void testItemFailures () {
204- prepareCreate (INDEX_NAME ).setMapping (
205- String .format (
206- Locale .ROOT ,
207- """
208- {
209- "properties": {
210- "sparse_field": {
211- "type": "semantic_text",
212- "inference_id": "%s"
213- },
214- "dense_field": {
215- "type": "semantic_text",
216- "inference_id": "%s"
217- }
218- }
219- }
220- """ ,
221- TestSparseInferenceServiceExtension .TestInferenceService .NAME ,
222- TestDenseInferenceServiceExtension .TestInferenceService .NAME
223- )
224- ).get ();
225-
226- BulkRequestBuilder bulkReqBuilder = client ().prepareBulk ();
227- int totalBulkSize = randomIntBetween (100 , 200 ); // Use a bulk request size large enough to require batching
228- for (int bulkSize = 0 ; bulkSize < totalBulkSize ; bulkSize ++) {
229- String id = Integer .toString (bulkSize );
230-
231- // Set field values that will cause errors when generating inference requests
232- Map <String , Object > source = new HashMap <>();
233- source .put ("sparse_field" , List .of (Map .of ("foo" , "bar" ), Map .of ("baz" , "bar" )));
234- source .put ("dense_field" , List .of (Map .of ("foo" , "bar" ), Map .of ("baz" , "bar" )));
235-
236- bulkReqBuilder .add (new IndexRequestBuilder (client ()).setIndex (INDEX_NAME ).setId (id ).setSource (source ));
237- }
238-
239- BulkResponse bulkResponse = bulkReqBuilder .get ();
240- assertThat (bulkResponse .hasFailures (), equalTo (true ));
241- for (BulkItemResponse bulkItemResponse : bulkResponse .getItems ()) {
242- assertThat (bulkItemResponse .isFailed (), equalTo (true ));
243- assertThat (bulkItemResponse .getFailureMessage (), containsString ("expected [String|Number|Boolean]" ));
244- }
245- }
246310}
0 commit comments