Skip to content

Commit 7f40f95

Browse files
committed
override getIndexReaderManager for SemanticQueryBuilderTests
1 parent 7a68727 commit 7f40f95

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1111

12+
import org.apache.lucene.document.Document;
13+
import org.apache.lucene.document.Field;
14+
import org.apache.lucene.document.FloatDocValuesField;
15+
import org.apache.lucene.document.TextField;
1216
import org.apache.lucene.search.BooleanClause;
1317
import org.apache.lucene.search.BooleanQuery;
1418
import org.apache.lucene.search.BoostQuery;
@@ -17,6 +21,7 @@
1721
import org.apache.lucene.search.MatchNoDocsQuery;
1822
import org.apache.lucene.search.Query;
1923
import org.apache.lucene.search.join.ScoreMode;
24+
import org.apache.lucene.tests.index.RandomIndexWriter;
2025
import org.elasticsearch.action.ActionListener;
2126
import org.elasticsearch.action.ActionRequest;
2227
import org.elasticsearch.action.ActionType;
@@ -30,6 +35,7 @@
3035
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
3136
import org.elasticsearch.common.settings.Settings;
3237
import org.elasticsearch.core.IOUtils;
38+
import org.elasticsearch.core.Nullable;
3339
import org.elasticsearch.index.IndexVersion;
3440
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
3541
import org.elasticsearch.index.mapper.MapperService;
@@ -99,6 +105,7 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
99105
private static DenseVectorFieldMapper.ElementType denseVectorElementType;
100106
private static boolean useSearchInferenceId;
101107
private final boolean useLegacyFormat;
108+
private MapperService currentMapperService;
102109

103110
private enum InferenceResultType {
104111
NONE,
@@ -180,6 +187,22 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws
180187
applyRandomInferenceResults(mapperService);
181188
}
182189

190+
@Override
191+
protected IndexReaderManager getIndexReaderManager() {
192+
return new IndexReaderManager() {
193+
@Override
194+
protected void initIndexWriter(RandomIndexWriter indexWriter) {
195+
Document document = new Document();
196+
document.add(new TextField("semantic.inference.chunks.embeddings", "a b x y", Field.Store.NO));
197+
try {
198+
indexWriter.addDocument(document);
199+
} catch (IOException e) {
200+
throw new RuntimeException(e);
201+
}
202+
}
203+
};
204+
}
205+
183206
private void applyRandomInferenceResults(MapperService mapperService) throws IOException {
184207
// Parse random inference results (or no inference results) to set up the dynamic inference result mappings under the semantic text
185208
// field
@@ -240,12 +263,8 @@ private void assertSparseEmbeddingLuceneQuery(Query query) {
240263
assertThat(sparseQuery.getTermsQuery(), instanceOf(BooleanQuery.class));
241264

242265
BooleanQuery innerBooleanQuery = (BooleanQuery) sparseQuery.getTermsQuery();
243-
assertThat(innerBooleanQuery.clauses().size(), equalTo(queryTokenCount));
244-
innerBooleanQuery.forEach(c -> {
245-
assertThat(c.occur(), equalTo(SHOULD));
246-
assertThat(c.query(), instanceOf(BoostQuery.class));
247-
assertThat(((BoostQuery) c.query()).getBoost(), equalTo(TOKEN_WEIGHT));
248-
});
266+
// no clauses as tokens would be pruned
267+
assertThat(innerBooleanQuery.clauses().size(), equalTo(0));
249268
}
250269

251270
private void assertTextEmbeddingLuceneQuery(Query query) {
@@ -376,18 +395,7 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults(
376395
DenseVectorFieldMapper.ElementType denseVectorElementType,
377396
boolean useLegacyFormat
378397
) throws IOException {
379-
var modelSettings = switch (inferenceResultType) {
380-
case NONE -> null;
381-
case SPARSE_EMBEDDING -> new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null);
382-
case TEXT_EMBEDDING -> new MinimalServiceSettings(
383-
"my-service",
384-
TaskType.TEXT_EMBEDDING,
385-
TEXT_EMBEDDING_DIMENSION_COUNT,
386-
// l2_norm similarity is required for bit embeddings
387-
denseVectorElementType == DenseVectorFieldMapper.ElementType.BIT ? SimilarityMeasure.L2_NORM : SimilarityMeasure.COSINE,
388-
denseVectorElementType
389-
);
390-
};
398+
var modelSettings = getModelSettingsForInferenceResultType(inferenceResultType, denseVectorElementType);
391399

392400
SourceToParse sourceToParse = null;
393401
if (modelSettings != null) {
@@ -414,6 +422,23 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults(
414422
return sourceToParse;
415423
}
416424

425+
private static MinimalServiceSettings getModelSettingsForInferenceResultType(
426+
InferenceResultType inferenceResultType, @Nullable DenseVectorFieldMapper.ElementType denseVectorElementType
427+
) {
428+
return switch (inferenceResultType) {
429+
case NONE -> null;
430+
case SPARSE_EMBEDDING -> new MinimalServiceSettings("my-service", TaskType.SPARSE_EMBEDDING, null, null, null);
431+
case TEXT_EMBEDDING -> new MinimalServiceSettings(
432+
"my-service",
433+
TaskType.TEXT_EMBEDDING,
434+
TEXT_EMBEDDING_DIMENSION_COUNT,
435+
// l2_norm similarity is required for bit embeddings
436+
denseVectorElementType == DenseVectorFieldMapper.ElementType.BIT ? SimilarityMeasure.L2_NORM : SimilarityMeasure.COSINE,
437+
denseVectorElementType
438+
);
439+
};
440+
}
441+
417442
public static class FakeMlPlugin extends Plugin {
418443
@Override
419444
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {

0 commit comments

Comments
 (0)