Skip to content

Commit 9b96d2f

Browse files
committed
refactor(test): add sparse vector pruning tests (#132264)
(cherry picked from commit 6cd330c)

# Conflicts:
#	server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java
1 parent 472e913 commit 9b96d2f

File tree

4 files changed

+193
-176
lines changed

4 files changed

+193
-176
lines changed

server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java

Lines changed: 147 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.index.LeafReader;
1717
import org.apache.lucene.search.BooleanClause;
1818
import org.apache.lucene.search.BooleanQuery;
19+
import org.apache.lucene.search.BoostQuery;
1920
import org.apache.lucene.search.IndexSearcher;
2021
import org.apache.lucene.search.Query;
2122
import org.apache.lucene.store.Directory;
@@ -53,6 +54,10 @@
5354
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
5661

5762
import static org.elasticsearch.index.IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT;
5863
import static org.elasticsearch.index.IndexVersions.UPGRADE_TO_LUCENE_10_0_0;
@@ -67,6 +72,22 @@
6772

6873
public class SparseVectorFieldMapperTests extends MapperTestCase {
6974

75+
public static final float STRICT_TOKENS_WEIGHT_THRESHOLD = 0.5f;
76+
public static final float STRICT_TOKENS_FREQ_RATIO_THRESHOLD = 1;
77+
78+
private static final Map<String, Float> COMMON_TOKENS = Map.of(
79+
"common1_drop_default",
80+
0.1f,
81+
"common2_drop_default",
82+
0.1f,
83+
"common3_drop_default",
84+
0.1f
85+
);
86+
87+
private static final Map<String, Float> MEDIUM_TOKENS = Map.of("medium1_keep_strict", 0.5f, "medium2_keep_default", 0.25f);
88+
89+
private static final Map<String, Float> RARE_TOKENS = Map.of("rare1_keep_strict", 0.9f, "rare2_keep_strict", 0.85f);
90+
7091
@Override
7192
protected Object getSampleValueForDocument() {
7293
Map<String, Float> map = new LinkedHashMap<>();
@@ -122,7 +143,7 @@ protected void minimalMappingWithExplicitIndexOptions(XContentBuilder b) throws
122143
b.field("prune", true);
123144
b.startObject("pruning_config");
124145
{
125-
b.field("tokens_freq_ratio_threshold", 3.0f);
146+
b.field("tokens_freq_ratio_threshold", 1.0f);
126147
b.field("tokens_weight_threshold", 0.5f);
127148
}
128149
b.endObject();
@@ -177,6 +198,13 @@ protected void mappingWithIndexOptionsPruneFalse(XContentBuilder b) throws IOExc
177198
b.endObject();
178199
}
179200

201+
private void mapping(XContentBuilder b, @Nullable Boolean prune, PruningConfig pruningConfig) throws IOException {
202+
b.field("type", "sparse_vector");
203+
if (prune != null) {
204+
b.field("index_options", new SparseVectorFieldMapper.SparseVectorIndexOptions(prune, pruningConfig.tokenPruningConfig));
205+
}
206+
}
207+
180208
@Override
181209
protected boolean supportsStoredFields() {
182210
return false;
@@ -678,14 +706,58 @@ public void testTokensWeightThresholdCorrect() {
678706
);
679707
}
680708

709+
private enum PruningScenario {
710+
NO_PRUNING, // No pruning applied - all tokens preserved
711+
DEFAULT_PRUNING, // Default pruning configuration
712+
STRICT_PRUNING // Stricter pruning with higher thresholds
713+
}
714+
715+
private enum PruningConfig {
716+
NULL(null),
717+
EXPLICIT_DEFAULT(new TokenPruningConfig()),
718+
STRICT(new TokenPruningConfig(STRICT_TOKENS_FREQ_RATIO_THRESHOLD, STRICT_TOKENS_WEIGHT_THRESHOLD, false));
719+
720+
public final @Nullable TokenPruningConfig tokenPruningConfig;
721+
722+
PruningConfig(@Nullable TokenPruningConfig tokenPruningConfig) {
723+
this.tokenPruningConfig = tokenPruningConfig;
724+
}
725+
}
726+
727+
private final Set<PruningOptions> validIndexPruningScenarios = Set.of(
728+
new PruningOptions(false, PruningConfig.NULL),
729+
new PruningOptions(true, PruningConfig.NULL),
730+
new PruningOptions(true, PruningConfig.EXPLICIT_DEFAULT),
731+
new PruningOptions(true, PruningConfig.STRICT),
732+
new PruningOptions(null, PruningConfig.NULL)
733+
);
734+
735+
private record PruningOptions(@Nullable Boolean prune, PruningConfig pruningConfig) {}
736+
681737
private void withSearchExecutionContext(MapperService mapperService, CheckedConsumer<SearchExecutionContext, IOException> consumer)
682738
throws IOException {
683739
var mapper = mapperService.documentMapper();
684740
try (Directory directory = newDirectory()) {
685741
RandomIndexWriter iw = new RandomIndexWriter(random(), directory);
686-
var sourceToParse = source(this::writeField);
687-
ParsedDocument doc1 = mapper.parse(sourceToParse);
688-
iw.addDocument(doc1.rootDoc());
742+
743+
int commonDocs = 20;
744+
for (int i = 0; i < commonDocs; i++) {
745+
iw.addDocument(mapper.parse(source(b -> b.field("field", COMMON_TOKENS))).rootDoc());
746+
}
747+
748+
int mediumDocs = 5;
749+
for (int i = 0; i < mediumDocs; i++) {
750+
iw.addDocument(mapper.parse(source(b -> b.field("field", MEDIUM_TOKENS))).rootDoc());
751+
}
752+
753+
iw.addDocument(mapper.parse(source(b -> b.field("field", RARE_TOKENS))).rootDoc());
754+
755+
// This will lower the averageTokenFreqRatio so that common tokens get pruned with default settings
756+
Map<String, Float> uniqueDoc = new TreeMap<>();
757+
for (int i = 0; i < 20; i++) {
758+
uniqueDoc.put("unique" + i, 0.5f);
759+
}
760+
iw.addDocument(mapper.parse(source(b -> b.field("field", uniqueDoc))).rootDoc());
689761
iw.close();
690762

691763
try (DirectoryReader reader = wrapInMockESDirectoryReader(DirectoryReader.open(directory))) {
@@ -695,163 +767,97 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
695767
}
696768
}
697769

698-
public void testTypeQueryFinalizationWithRandomOptions() throws Exception {
699-
for (int i = 0; i < 20; i++) {
700-
runTestTypeQueryFinalization(
701-
randomBoolean(), // useIndexVersionBeforeIndexOptions
702-
randomBoolean(), // useMapperDefaultIndexOptions
703-
randomBoolean(), // setMapperIndexOptionsPruneToFalse
704-
randomBoolean(), // queryOverridesPruningConfig
705-
randomBoolean() // queryOverridesPruneToBeFalse
770+
public void testPruningScenarios() throws Exception {
771+
for (int i = 0; i < 120; i++) {
772+
assertPruningScenario(
773+
randomFrom(validIndexPruningScenarios),
774+
new PruningOptions(randomBoolean() ? randomBoolean() : null, randomFrom(PruningConfig.values()))
706775
);
707776
}
708777
}
709778

710-
public void testTypeQueryFinalizationDefaultsCurrentVersion() throws Exception {
711-
IndexVersion version = IndexVersion.current();
712-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
713-
714-
// query should be pruned by default on newer index versions
715-
performTypeQueryFinalizationTest(mapperService, null, null, true);
716-
}
717-
718-
public void testTypeQueryFinalizationDefaultsPreviousVersion() throws Exception {
719-
IndexVersion version = IndexVersionUtils.randomVersionBetween(
720-
random(),
721-
UPGRADE_TO_LUCENE_10_0_0,
722-
IndexVersionUtils.getPreviousVersion(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT)
723-
);
724-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
725-
726-
// query should _not_ be pruned by default on older index versions
727-
performTypeQueryFinalizationTest(mapperService, null, null, false);
779+
private XContentBuilder getIndexMapping(PruningOptions pruningOptions) throws IOException {
780+
return fieldMapping(b -> mapping(b, pruningOptions.prune, pruningOptions.pruningConfig));
728781
}
729782

730-
public void testTypeQueryFinalizationWithIndexExplicit() throws Exception {
731-
IndexVersion version = IndexVersion.current();
732-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
783+
private void assertQueryContains(List<Query> expectedClauses, Query query) {
784+
SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper) query;
785+
var termsQuery = queryWrapper.getTermsQuery();
786+
assertNotNull(termsQuery);
787+
var booleanQuery = (BooleanQuery) termsQuery;
733788

734-
// query should be pruned via explicit index options
735-
performTypeQueryFinalizationTest(mapperService, null, null, true);
789+
Collection<Query> shouldClauses = booleanQuery.getClauses(BooleanClause.Occur.SHOULD);
790+
assertThat(shouldClauses, Matchers.containsInAnyOrder(expectedClauses.toArray()));
736791
}
737792

738-
public void testTypeQueryFinalizationWithIndexExplicitDoNotPrune() throws Exception {
739-
IndexVersion version = IndexVersion.current();
740-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
793+
private PruningScenario getEffectivePruningScenario(
794+
PruningOptions indexPruningOptions,
795+
PruningOptions queryPruningOptions,
796+
IndexVersion indexVersion
797+
) {
798+
Boolean shouldPrune = queryPruningOptions.prune;
799+
if (shouldPrune == null) {
800+
shouldPrune = indexPruningOptions.prune;
801+
}
741802

742-
// query should be pruned via explicit index options
743-
performTypeQueryFinalizationTest(mapperService, null, null, false);
744-
}
803+
if (shouldPrune == null) {
804+
shouldPrune = indexVersion.onOrAfter(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT);
805+
}
745806

746-
public void testTypeQueryFinalizationQueryOverridesPruning() throws Exception {
747-
IndexVersion version = IndexVersion.current();
748-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
807+
PruningScenario pruningScenario = PruningScenario.NO_PRUNING;
808+
if (shouldPrune) {
809+
PruningConfig pruningConfig = queryPruningOptions.pruningConfig;
810+
if (pruningConfig == PruningConfig.NULL) {
811+
pruningConfig = indexPruningOptions.pruningConfig;
812+
}
813+
pruningScenario = switch (pruningConfig) {
814+
case STRICT -> PruningScenario.STRICT_PRUNING;
815+
case EXPLICIT_DEFAULT, NULL -> PruningScenario.DEFAULT_PRUNING;
816+
};
817+
}
749818

750-
// query should still be pruned due to query builder setting it
751-
performTypeQueryFinalizationTest(mapperService, true, new TokenPruningConfig(), true);
819+
return pruningScenario;
752820
}
753821

754-
public void testTypeQueryFinalizationQueryOverridesPruningOff() throws Exception {
755-
IndexVersion version = IndexVersion.current();
756-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
822+
private List<Query> getExpectedQueryClauses(
823+
SparseVectorFieldMapper.SparseVectorFieldType ft,
824+
PruningScenario pruningScenario,
825+
SearchExecutionContext searchExecutionContext
826+
) {
827+
List<WeightedToken> tokens = switch (pruningScenario) {
828+
case NO_PRUNING -> QUERY_VECTORS;
829+
case DEFAULT_PRUNING -> QUERY_VECTORS.stream()
830+
.filter(t -> t.token().startsWith("rare") || t.token().startsWith("medium"))
831+
.toList();
832+
case STRICT_PRUNING -> QUERY_VECTORS.stream().filter(t -> t.token().endsWith("keep_strict")).toList();
833+
};
757834

758-
// query should not pruned due to query builder setting it
759-
performTypeQueryFinalizationTest(mapperService, false, null, false);
835+
return tokens.stream().map(t -> {
836+
Query termQuery = ft.termQuery(t.token(), searchExecutionContext);
837+
return new BoostQuery(termQuery, t.weight());
838+
}).collect(Collectors.toUnmodifiableList());
760839
}
761840

762-
private void performTypeQueryFinalizationTest(
763-
MapperService mapperService,
764-
@Nullable Boolean queryPrune,
765-
@Nullable TokenPruningConfig queryTokenPruningConfig,
766-
boolean queryShouldBePruned
767-
) throws IOException {
841+
private void assertPruningScenario(PruningOptions indexPruningOptions, PruningOptions queryPruningOptions) throws IOException {
842+
IndexVersion indexVersion = getIndexVersionForTest(randomBoolean());
843+
MapperService mapperService = createMapperService(indexVersion, getIndexMapping(indexPruningOptions));
844+
PruningScenario effectivePruningScenario = getEffectivePruningScenario(indexPruningOptions, queryPruningOptions, indexVersion);
768845
withSearchExecutionContext(mapperService, (context) -> {
769846
SparseVectorFieldMapper.SparseVectorFieldType ft = (SparseVectorFieldMapper.SparseVectorFieldType) mapperService.fieldType(
770847
"field"
771848
);
772-
Query finalizedQuery = ft.finalizeSparseVectorQuery(context, "field", QUERY_VECTORS, queryPrune, queryTokenPruningConfig);
773-
774-
if (queryShouldBePruned) {
775-
assertQueryWasPruned(finalizedQuery);
776-
} else {
777-
assertQueryWasNotPruned(finalizedQuery);
778-
}
849+
List<Query> expectedQueryClauses = getExpectedQueryClauses(ft, effectivePruningScenario, context);
850+
Query finalizedQuery = ft.finalizeSparseVectorQuery(
851+
context,
852+
"field",
853+
QUERY_VECTORS,
854+
queryPruningOptions.prune,
855+
queryPruningOptions.pruningConfig.tokenPruningConfig
856+
);
857+
assertQueryContains(expectedQueryClauses, finalizedQuery);
779858
});
780859
}
781860

782-
private void assertQueryWasPruned(Query query) {
783-
assertQueryHasClauseCount(query, 0);
784-
}
785-
786-
private void assertQueryWasNotPruned(Query query) {
787-
assertQueryHasClauseCount(query, QUERY_VECTORS.size());
788-
}
789-
790-
private void assertQueryHasClauseCount(Query query, int clauseCount) {
791-
SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper) query;
792-
var termsQuery = queryWrapper.getTermsQuery();
793-
assertNotNull(termsQuery);
794-
var booleanQuery = (BooleanQuery) termsQuery;
795-
Collection<Query> clauses = booleanQuery.getClauses(BooleanClause.Occur.SHOULD);
796-
assertThat(clauses.size(), equalTo(clauseCount));
797-
}
798-
799-
/**
800-
* Runs a test of the query finalization based on various parameters
801-
* that provides
802-
* @param useIndexVersionBeforeIndexOptions set to true to use a previous index version before mapper index_options
803-
* @param useMapperDefaultIndexOptions set to false to use an explicit, non-default mapper index_options
804-
* @param setMapperIndexOptionsPruneToFalse set to true to use prune:false in the mapper index_options
805-
* @param queryOverridesPruningConfig set to true to designate the query will provide a pruning_config
806-
* @param queryOverridesPruneToBeFalse if true and queryOverridesPruningConfig is true, the query will provide prune:false
807-
* @throws IOException
808-
*/
809-
private void runTestTypeQueryFinalization(
810-
boolean useIndexVersionBeforeIndexOptions,
811-
boolean useMapperDefaultIndexOptions,
812-
boolean setMapperIndexOptionsPruneToFalse,
813-
boolean queryOverridesPruningConfig,
814-
boolean queryOverridesPruneToBeFalse
815-
) throws IOException {
816-
MapperService mapperService = getMapperServiceForTest(
817-
useIndexVersionBeforeIndexOptions,
818-
useMapperDefaultIndexOptions,
819-
setMapperIndexOptionsPruneToFalse
820-
);
821-
822-
// check and see if the query should explicitly override the index_options
823-
Boolean shouldQueryPrune = queryOverridesPruningConfig ? (queryOverridesPruneToBeFalse == false) : null;
824-
825-
// get the pruning configuration for the query if it's overriding
826-
TokenPruningConfig queryPruningConfig = Boolean.TRUE.equals(shouldQueryPrune) ? new TokenPruningConfig() : null;
827-
828-
// our logic if the results should be pruned or not
829-
// we should _not_ prune if any of the following:
830-
// - the query explicitly overrides the options and `prune` is set to false
831-
// - the query does not override the pruning options and:
832-
// - either we are using a previous index version
833-
// - or the index_options explicitly sets `prune` to false
834-
boolean resultShouldNotBePruned = ((queryOverridesPruningConfig && queryOverridesPruneToBeFalse)
835-
|| (queryOverridesPruningConfig == false && (useIndexVersionBeforeIndexOptions || setMapperIndexOptionsPruneToFalse)));
836-
837-
try {
838-
performTypeQueryFinalizationTest(mapperService, shouldQueryPrune, queryPruningConfig, resultShouldNotBePruned == false);
839-
} catch (AssertionError e) {
840-
String message = "performTypeQueryFinalizationTest failed using parameters: "
841-
+ "useIndexVersionBeforeIndexOptions: "
842-
+ useIndexVersionBeforeIndexOptions
843-
+ ", useMapperDefaultIndexOptions: "
844-
+ useMapperDefaultIndexOptions
845-
+ ", setMapperIndexOptionsPruneToFalse: "
846-
+ setMapperIndexOptionsPruneToFalse
847-
+ ", queryOverridesPruningConfig: "
848-
+ queryOverridesPruningConfig
849-
+ ", queryOverridesPruneToBeFalse: "
850-
+ queryOverridesPruneToBeFalse;
851-
throw new AssertionError(message, e);
852-
}
853-
}
854-
855861
private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
856862
return usePreviousIndex
857863
? IndexVersionUtils.randomVersionBetween(
@@ -862,36 +868,10 @@ private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
862868
: IndexVersionUtils.randomVersionBetween(random(), SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT, IndexVersion.current());
863869
}
864870

865-
private MapperService getMapperServiceForTest(
866-
boolean usePreviousIndex,
867-
boolean useIndexOptionsDefaults,
868-
boolean explicitIndexOptionsDoNotPrune
869-
) throws IOException {
870-
// get the index version of the test to use
871-
// either a current version that supports index options, or a previous version that does not
872-
IndexVersion indexVersion = getIndexVersionForTest(usePreviousIndex);
873-
874-
// if it's using the old index, we always use the minimal mapping without index_options
875-
if (usePreviousIndex) {
876-
return createMapperService(indexVersion, fieldMapping(this::minimalMapping));
877-
}
878-
879-
// if we set explicitIndexOptionsDoNotPrune, the index_options (if present) will explicitly include "prune: false"
880-
if (explicitIndexOptionsDoNotPrune) {
881-
return createMapperService(indexVersion, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
882-
}
883-
884-
// either return the default (minimal) mapping or one with an explicit pruning_config
885-
return useIndexOptionsDefaults
886-
? createMapperService(indexVersion, fieldMapping(this::minimalMapping))
887-
: createMapperService(indexVersion, fieldMapping(this::minimalMappingWithExplicitIndexOptions));
888-
}
889-
890-
private static List<WeightedToken> QUERY_VECTORS = List.of(
891-
new WeightedToken("pugs", 0.5f),
892-
new WeightedToken("cats", 0.4f),
893-
new WeightedToken("is", 0.1f)
894-
);
871+
private static final List<WeightedToken> QUERY_VECTORS = Stream.of(RARE_TOKENS, MEDIUM_TOKENS, COMMON_TOKENS)
872+
.flatMap(map -> map.entrySet().stream())
873+
.map(entry -> new WeightedToken(entry.getKey(), entry.getValue()))
874+
.collect(Collectors.toList());
895875

896876
/**
897877
* Handles float/double conversion when reading/writing with xcontent by converting all numbers to floats.

0 commit comments

Comments
 (0)