99
1010import com .carrotsearch .randomizedtesting .annotations .ParametersFactory ;
1111
12+ import org .apache .lucene .document .Document ;
13+ import org .apache .lucene .document .Field ;
14+ import org .apache .lucene .document .FloatDocValuesField ;
15+ import org .apache .lucene .document .TextField ;
1216import org .apache .lucene .search .BooleanClause ;
1317import org .apache .lucene .search .BooleanQuery ;
1418import org .apache .lucene .search .BoostQuery ;
1721import org .apache .lucene .search .MatchNoDocsQuery ;
1822import org .apache .lucene .search .Query ;
1923import org .apache .lucene .search .join .ScoreMode ;
24+ import org .apache .lucene .tests .index .RandomIndexWriter ;
2025import org .elasticsearch .action .ActionListener ;
2126import org .elasticsearch .action .ActionRequest ;
2227import org .elasticsearch .action .ActionType ;
3035import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
3136import org .elasticsearch .common .settings .Settings ;
3237import org .elasticsearch .core .IOUtils ;
38+ import org .elasticsearch .core .Nullable ;
3339import org .elasticsearch .index .IndexVersion ;
3440import org .elasticsearch .index .mapper .InferenceMetadataFieldsMapper ;
3541import org .elasticsearch .index .mapper .MapperService ;
@@ -99,6 +105,7 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
99105 private static DenseVectorFieldMapper .ElementType denseVectorElementType ;
100106 private static boolean useSearchInferenceId ;
101107 private final boolean useLegacyFormat ;
108+ private MapperService currentMapperService ;
102109
103110 private enum InferenceResultType {
104111 NONE ,
@@ -180,6 +187,22 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
180187 applyRandomInferenceResults (mapperService );
181188 }
182189
190+ @ Override
191+ protected IndexReaderManager getIndexReaderManager () {
192+ return new IndexReaderManager () {
193+ @ Override
194+ protected void initIndexWriter (RandomIndexWriter indexWriter ) {
195+ Document document = new Document ();
196+ document .add (new TextField ("semantic.inference.chunks.embeddings" , "a b x y" , Field .Store .NO ));
197+ try {
198+ indexWriter .addDocument (document );
199+ } catch (IOException e ) {
200+ throw new RuntimeException (e );
201+ }
202+ }
203+ };
204+ }
205+
183206 private void applyRandomInferenceResults (MapperService mapperService ) throws IOException {
184207 // Parse random inference results (or no inference results) to set up the dynamic inference result mappings under the semantic text
185208 // field
@@ -240,12 +263,8 @@ private void assertSparseEmbeddingLuceneQuery(Query query) {
240263 assertThat (sparseQuery .getTermsQuery (), instanceOf (BooleanQuery .class ));
241264
242265 BooleanQuery innerBooleanQuery = (BooleanQuery ) sparseQuery .getTermsQuery ();
243- assertThat (innerBooleanQuery .clauses ().size (), equalTo (queryTokenCount ));
244- innerBooleanQuery .forEach (c -> {
245- assertThat (c .occur (), equalTo (SHOULD ));
246- assertThat (c .query (), instanceOf (BoostQuery .class ));
247- assertThat (((BoostQuery ) c .query ()).getBoost (), equalTo (TOKEN_WEIGHT ));
248- });
266+ // no clauses as tokens would be pruned
267+ assertThat (innerBooleanQuery .clauses ().size (), equalTo (0 ));
249268 }
250269
251270 private void assertTextEmbeddingLuceneQuery (Query query ) {
@@ -376,18 +395,7 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults(
376395 DenseVectorFieldMapper .ElementType denseVectorElementType ,
377396 boolean useLegacyFormat
378397 ) throws IOException {
379- var modelSettings = switch (inferenceResultType ) {
380- case NONE -> null ;
381- case SPARSE_EMBEDDING -> new MinimalServiceSettings ("my-service" , TaskType .SPARSE_EMBEDDING , null , null , null );
382- case TEXT_EMBEDDING -> new MinimalServiceSettings (
383- "my-service" ,
384- TaskType .TEXT_EMBEDDING ,
385- TEXT_EMBEDDING_DIMENSION_COUNT ,
386- // l2_norm similarity is required for bit embeddings
387- denseVectorElementType == DenseVectorFieldMapper .ElementType .BIT ? SimilarityMeasure .L2_NORM : SimilarityMeasure .COSINE ,
388- denseVectorElementType
389- );
390- };
398+ var modelSettings = getModelSettingsForInferenceResultType (inferenceResultType , denseVectorElementType );
391399
392400 SourceToParse sourceToParse = null ;
393401 if (modelSettings != null ) {
@@ -414,6 +422,23 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults(
414422 return sourceToParse ;
415423 }
416424
425+ private static MinimalServiceSettings getModelSettingsForInferenceResultType (
426+ InferenceResultType inferenceResultType , @ Nullable DenseVectorFieldMapper .ElementType denseVectorElementType
427+ ) {
428+ return switch (inferenceResultType ) {
429+ case NONE -> null ;
430+ case SPARSE_EMBEDDING -> new MinimalServiceSettings ("my-service" , TaskType .SPARSE_EMBEDDING , null , null , null );
431+ case TEXT_EMBEDDING -> new MinimalServiceSettings (
432+ "my-service" ,
433+ TaskType .TEXT_EMBEDDING ,
434+ TEXT_EMBEDDING_DIMENSION_COUNT ,
435+ // l2_norm similarity is required for bit embeddings
436+ denseVectorElementType == DenseVectorFieldMapper .ElementType .BIT ? SimilarityMeasure .L2_NORM : SimilarityMeasure .COSINE ,
437+ denseVectorElementType
438+ );
439+ };
440+ }
441+
417442 public static class FakeMlPlugin extends Plugin {
418443 @ Override
419444 public List <NamedWriteableRegistry .Entry > getNamedWriteables () {
0 commit comments