Skip to content

Commit 6cd330c

Browse files
authored
refactor(test): add sparse vector pruning tests (#132264)
1 parent be3eed9 commit 6cd330c

File tree

4 files changed

+192
-176
lines changed

4 files changed

+192
-176
lines changed

server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java

Lines changed: 146 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.apache.lucene.index.LeafReader;
1818
import org.apache.lucene.search.BooleanClause;
1919
import org.apache.lucene.search.BooleanQuery;
20+
import org.apache.lucene.search.BoostQuery;
2021
import org.apache.lucene.search.IndexSearcher;
2122
import org.apache.lucene.search.Query;
2223
import org.apache.lucene.store.Directory;
@@ -54,7 +55,10 @@
5455
import java.util.LinkedHashMap;
5556
import java.util.List;
5657
import java.util.Map;
58+
import java.util.Set;
5759
import java.util.TreeMap;
60+
import java.util.stream.Collectors;
61+
import java.util.stream.Stream;
5862

5963
import static org.elasticsearch.index.IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT;
6064
import static org.elasticsearch.index.IndexVersions.UPGRADE_TO_LUCENE_10_0_0;
@@ -69,6 +73,22 @@
6973

7074
public class SparseVectorFieldMapperTests extends SyntheticVectorsMapperTestCase {
7175

76+
public static final float STRICT_TOKENS_WEIGHT_THRESHOLD = 0.5f;
77+
public static final float STRICT_TOKENS_FREQ_RATIO_THRESHOLD = 1;
78+
79+
private static final Map<String, Float> COMMON_TOKENS = Map.of(
80+
"common1_drop_default",
81+
0.1f,
82+
"common2_drop_default",
83+
0.1f,
84+
"common3_drop_default",
85+
0.1f
86+
);
87+
88+
private static final Map<String, Float> MEDIUM_TOKENS = Map.of("medium1_keep_strict", 0.5f, "medium2_keep_default", 0.25f);
89+
90+
private static final Map<String, Float> RARE_TOKENS = Map.of("rare1_keep_strict", 0.9f, "rare2_keep_strict", 0.85f);
91+
7292
@Override
7393
protected Object getSampleValueForDocument() {
7494
return new TreeMap<>(
@@ -123,7 +143,7 @@ protected void minimalMappingWithExplicitIndexOptions(XContentBuilder b) throws
123143
b.field("prune", true);
124144
b.startObject("pruning_config");
125145
{
126-
b.field("tokens_freq_ratio_threshold", 3.0f);
146+
b.field("tokens_freq_ratio_threshold", 1.0f);
127147
b.field("tokens_weight_threshold", 0.5f);
128148
}
129149
b.endObject();
@@ -178,6 +198,13 @@ protected void mappingWithIndexOptionsPruneFalse(XContentBuilder b) throws IOExc
178198
b.endObject();
179199
}
180200

201+
private void mapping(XContentBuilder b, @Nullable Boolean prune, PruningConfig pruningConfig) throws IOException {
202+
b.field("type", "sparse_vector");
203+
if (prune != null) {
204+
b.field("index_options", new SparseVectorFieldMapper.SparseVectorIndexOptions(prune, pruningConfig.tokenPruningConfig));
205+
}
206+
}
207+
181208
@Override
182209
protected boolean supportsStoredFields() {
183210
return false;
@@ -676,14 +703,58 @@ public void testTokensWeightThresholdCorrect() {
676703
);
677704
}
678705

706+
private enum PruningScenario {
707+
NO_PRUNING, // No pruning applied - all tokens preserved
708+
DEFAULT_PRUNING, // Default pruning configuration
709+
STRICT_PRUNING // Stricter pruning with higher thresholds
710+
}
711+
712+
private enum PruningConfig {
713+
NULL(null),
714+
EXPLICIT_DEFAULT(new TokenPruningConfig()),
715+
STRICT(new TokenPruningConfig(STRICT_TOKENS_FREQ_RATIO_THRESHOLD, STRICT_TOKENS_WEIGHT_THRESHOLD, false));
716+
717+
public final @Nullable TokenPruningConfig tokenPruningConfig;
718+
719+
PruningConfig(@Nullable TokenPruningConfig tokenPruningConfig) {
720+
this.tokenPruningConfig = tokenPruningConfig;
721+
}
722+
}
723+
724+
private final Set<PruningOptions> validIndexPruningScenarios = Set.of(
725+
new PruningOptions(false, PruningConfig.NULL),
726+
new PruningOptions(true, PruningConfig.NULL),
727+
new PruningOptions(true, PruningConfig.EXPLICIT_DEFAULT),
728+
new PruningOptions(true, PruningConfig.STRICT),
729+
new PruningOptions(null, PruningConfig.NULL)
730+
);
731+
732+
private record PruningOptions(@Nullable Boolean prune, PruningConfig pruningConfig) {}
733+
679734
private void withSearchExecutionContext(MapperService mapperService, CheckedConsumer<SearchExecutionContext, IOException> consumer)
680735
throws IOException {
681736
var mapper = mapperService.documentMapper();
682737
try (Directory directory = newDirectory()) {
683738
RandomIndexWriter iw = new RandomIndexWriter(random(), directory);
684-
var sourceToParse = source(this::writeField);
685-
ParsedDocument doc1 = mapper.parse(sourceToParse);
686-
iw.addDocument(doc1.rootDoc());
739+
740+
int commonDocs = 20;
741+
for (int i = 0; i < commonDocs; i++) {
742+
iw.addDocument(mapper.parse(source(b -> b.field("field", COMMON_TOKENS))).rootDoc());
743+
}
744+
745+
int mediumDocs = 5;
746+
for (int i = 0; i < mediumDocs; i++) {
747+
iw.addDocument(mapper.parse(source(b -> b.field("field", MEDIUM_TOKENS))).rootDoc());
748+
}
749+
750+
iw.addDocument(mapper.parse(source(b -> b.field("field", RARE_TOKENS))).rootDoc());
751+
752+
// This will lower the averageTokenFreqRatio so that common tokens get pruned with default settings
753+
Map<String, Float> uniqueDoc = new TreeMap<>();
754+
for (int i = 0; i < 20; i++) {
755+
uniqueDoc.put("unique" + i, 0.5f);
756+
}
757+
iw.addDocument(mapper.parse(source(b -> b.field("field", uniqueDoc))).rootDoc());
687758
iw.close();
688759

689760
try (DirectoryReader reader = wrapInMockESDirectoryReader(DirectoryReader.open(directory))) {
@@ -693,163 +764,97 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
693764
}
694765
}
695766

696-
public void testTypeQueryFinalizationWithRandomOptions() throws Exception {
697-
for (int i = 0; i < 20; i++) {
698-
runTestTypeQueryFinalization(
699-
randomBoolean(), // useIndexVersionBeforeIndexOptions
700-
randomBoolean(), // useMapperDefaultIndexOptions
701-
randomBoolean(), // setMapperIndexOptionsPruneToFalse
702-
randomBoolean(), // queryOverridesPruningConfig
703-
randomBoolean() // queryOverridesPruneToBeFalse
767+
public void testPruningScenarios() throws Exception {
768+
for (int i = 0; i < 120; i++) {
769+
assertPruningScenario(
770+
randomFrom(validIndexPruningScenarios),
771+
new PruningOptions(randomBoolean() ? randomBoolean() : null, randomFrom(PruningConfig.values()))
704772
);
705773
}
706774
}
707775

708-
public void testTypeQueryFinalizationDefaultsCurrentVersion() throws Exception {
709-
IndexVersion version = IndexVersion.current();
710-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
711-
712-
// query should be pruned by default on newer index versions
713-
performTypeQueryFinalizationTest(mapperService, null, null, true);
714-
}
715-
716-
public void testTypeQueryFinalizationDefaultsPreviousVersion() throws Exception {
717-
IndexVersion version = IndexVersionUtils.randomVersionBetween(
718-
random(),
719-
UPGRADE_TO_LUCENE_10_0_0,
720-
IndexVersionUtils.getPreviousVersion(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT)
721-
);
722-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
723-
724-
// query should _not_ be pruned by default on older index versions
725-
performTypeQueryFinalizationTest(mapperService, null, null, false);
776+
private XContentBuilder getIndexMapping(PruningOptions pruningOptions) throws IOException {
777+
return fieldMapping(b -> mapping(b, pruningOptions.prune, pruningOptions.pruningConfig));
726778
}
727779

728-
public void testTypeQueryFinalizationWithIndexExplicit() throws Exception {
729-
IndexVersion version = IndexVersion.current();
730-
MapperService mapperService = createMapperService(version, fieldMapping(this::minimalMapping));
780+
private void assertQueryContains(List<Query> expectedClauses, Query query) {
781+
SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper) query;
782+
var termsQuery = queryWrapper.getTermsQuery();
783+
assertNotNull(termsQuery);
784+
var booleanQuery = (BooleanQuery) termsQuery;
731785

732-
// query should be pruned via explicit index options
733-
performTypeQueryFinalizationTest(mapperService, null, null, true);
786+
Collection<Query> shouldClauses = booleanQuery.getClauses(BooleanClause.Occur.SHOULD);
787+
assertThat(shouldClauses, Matchers.containsInAnyOrder(expectedClauses.toArray()));
734788
}
735789

736-
public void testTypeQueryFinalizationWithIndexExplicitDoNotPrune() throws Exception {
737-
IndexVersion version = IndexVersion.current();
738-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
790+
private PruningScenario getEffectivePruningScenario(
791+
PruningOptions indexPruningOptions,
792+
PruningOptions queryPruningOptions,
793+
IndexVersion indexVersion
794+
) {
795+
Boolean shouldPrune = queryPruningOptions.prune;
796+
if (shouldPrune == null) {
797+
shouldPrune = indexPruningOptions.prune;
798+
}
739799

740-
// query should be pruned via explicit index options
741-
performTypeQueryFinalizationTest(mapperService, null, null, false);
742-
}
800+
if (shouldPrune == null) {
801+
shouldPrune = indexVersion.onOrAfter(SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT);
802+
}
743803

744-
public void testTypeQueryFinalizationQueryOverridesPruning() throws Exception {
745-
IndexVersion version = IndexVersion.current();
746-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
804+
PruningScenario pruningScenario = PruningScenario.NO_PRUNING;
805+
if (shouldPrune) {
806+
PruningConfig pruningConfig = queryPruningOptions.pruningConfig;
807+
if (pruningConfig == PruningConfig.NULL) {
808+
pruningConfig = indexPruningOptions.pruningConfig;
809+
}
810+
pruningScenario = switch (pruningConfig) {
811+
case STRICT -> PruningScenario.STRICT_PRUNING;
812+
case EXPLICIT_DEFAULT, NULL -> PruningScenario.DEFAULT_PRUNING;
813+
};
814+
}
747815

748-
// query should still be pruned due to query builder setting it
749-
performTypeQueryFinalizationTest(mapperService, true, new TokenPruningConfig(), true);
816+
return pruningScenario;
750817
}
751818

752-
public void testTypeQueryFinalizationQueryOverridesPruningOff() throws Exception {
753-
IndexVersion version = IndexVersion.current();
754-
MapperService mapperService = createMapperService(version, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
819+
private List<Query> getExpectedQueryClauses(
820+
SparseVectorFieldMapper.SparseVectorFieldType ft,
821+
PruningScenario pruningScenario,
822+
SearchExecutionContext searchExecutionContext
823+
) {
824+
List<WeightedToken> tokens = switch (pruningScenario) {
825+
case NO_PRUNING -> QUERY_VECTORS;
826+
case DEFAULT_PRUNING -> QUERY_VECTORS.stream()
827+
.filter(t -> t.token().startsWith("rare") || t.token().startsWith("medium"))
828+
.toList();
829+
case STRICT_PRUNING -> QUERY_VECTORS.stream().filter(t -> t.token().endsWith("keep_strict")).toList();
830+
};
755831

756-
// query should not pruned due to query builder setting it
757-
performTypeQueryFinalizationTest(mapperService, false, null, false);
832+
return tokens.stream().map(t -> {
833+
Query termQuery = ft.termQuery(t.token(), searchExecutionContext);
834+
return new BoostQuery(termQuery, t.weight());
835+
}).collect(Collectors.toUnmodifiableList());
758836
}
759837

760-
private void performTypeQueryFinalizationTest(
761-
MapperService mapperService,
762-
@Nullable Boolean queryPrune,
763-
@Nullable TokenPruningConfig queryTokenPruningConfig,
764-
boolean queryShouldBePruned
765-
) throws IOException {
838+
private void assertPruningScenario(PruningOptions indexPruningOptions, PruningOptions queryPruningOptions) throws IOException {
839+
IndexVersion indexVersion = getIndexVersionForTest(randomBoolean());
840+
MapperService mapperService = createMapperService(indexVersion, getIndexMapping(indexPruningOptions));
841+
PruningScenario effectivePruningScenario = getEffectivePruningScenario(indexPruningOptions, queryPruningOptions, indexVersion);
766842
withSearchExecutionContext(mapperService, (context) -> {
767843
SparseVectorFieldMapper.SparseVectorFieldType ft = (SparseVectorFieldMapper.SparseVectorFieldType) mapperService.fieldType(
768844
"field"
769845
);
770-
Query finalizedQuery = ft.finalizeSparseVectorQuery(context, "field", QUERY_VECTORS, queryPrune, queryTokenPruningConfig);
771-
772-
if (queryShouldBePruned) {
773-
assertQueryWasPruned(finalizedQuery);
774-
} else {
775-
assertQueryWasNotPruned(finalizedQuery);
776-
}
846+
List<Query> expectedQueryClauses = getExpectedQueryClauses(ft, effectivePruningScenario, context);
847+
Query finalizedQuery = ft.finalizeSparseVectorQuery(
848+
context,
849+
"field",
850+
QUERY_VECTORS,
851+
queryPruningOptions.prune,
852+
queryPruningOptions.pruningConfig.tokenPruningConfig
853+
);
854+
assertQueryContains(expectedQueryClauses, finalizedQuery);
777855
});
778856
}
779857

780-
private void assertQueryWasPruned(Query query) {
781-
assertQueryHasClauseCount(query, 0);
782-
}
783-
784-
private void assertQueryWasNotPruned(Query query) {
785-
assertQueryHasClauseCount(query, QUERY_VECTORS.size());
786-
}
787-
788-
private void assertQueryHasClauseCount(Query query, int clauseCount) {
789-
SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper) query;
790-
var termsQuery = queryWrapper.getTermsQuery();
791-
assertNotNull(termsQuery);
792-
var booleanQuery = (BooleanQuery) termsQuery;
793-
Collection<Query> clauses = booleanQuery.getClauses(BooleanClause.Occur.SHOULD);
794-
assertThat(clauses.size(), equalTo(clauseCount));
795-
}
796-
797-
/**
798-
* Runs a test of the query finalization based on various parameters
799-
* that provides
800-
* @param useIndexVersionBeforeIndexOptions set to true to use a previous index version before mapper index_options
801-
* @param useMapperDefaultIndexOptions set to false to use an explicit, non-default mapper index_options
802-
* @param setMapperIndexOptionsPruneToFalse set to true to use prune:false in the mapper index_options
803-
* @param queryOverridesPruningConfig set to true to designate the query will provide a pruning_config
804-
* @param queryOverridesPruneToBeFalse if true and queryOverridesPruningConfig is true, the query will provide prune:false
805-
* @throws IOException
806-
*/
807-
private void runTestTypeQueryFinalization(
808-
boolean useIndexVersionBeforeIndexOptions,
809-
boolean useMapperDefaultIndexOptions,
810-
boolean setMapperIndexOptionsPruneToFalse,
811-
boolean queryOverridesPruningConfig,
812-
boolean queryOverridesPruneToBeFalse
813-
) throws IOException {
814-
MapperService mapperService = getMapperServiceForTest(
815-
useIndexVersionBeforeIndexOptions,
816-
useMapperDefaultIndexOptions,
817-
setMapperIndexOptionsPruneToFalse
818-
);
819-
820-
// check and see if the query should explicitly override the index_options
821-
Boolean shouldQueryPrune = queryOverridesPruningConfig ? (queryOverridesPruneToBeFalse == false) : null;
822-
823-
// get the pruning configuration for the query if it's overriding
824-
TokenPruningConfig queryPruningConfig = Boolean.TRUE.equals(shouldQueryPrune) ? new TokenPruningConfig() : null;
825-
826-
// our logic if the results should be pruned or not
827-
// we should _not_ prune if any of the following:
828-
// - the query explicitly overrides the options and `prune` is set to false
829-
// - the query does not override the pruning options and:
830-
// - either we are using a previous index version
831-
// - or the index_options explicitly sets `prune` to false
832-
boolean resultShouldNotBePruned = ((queryOverridesPruningConfig && queryOverridesPruneToBeFalse)
833-
|| (queryOverridesPruningConfig == false && (useIndexVersionBeforeIndexOptions || setMapperIndexOptionsPruneToFalse)));
834-
835-
try {
836-
performTypeQueryFinalizationTest(mapperService, shouldQueryPrune, queryPruningConfig, resultShouldNotBePruned == false);
837-
} catch (AssertionError e) {
838-
String message = "performTypeQueryFinalizationTest failed using parameters: "
839-
+ "useIndexVersionBeforeIndexOptions: "
840-
+ useIndexVersionBeforeIndexOptions
841-
+ ", useMapperDefaultIndexOptions: "
842-
+ useMapperDefaultIndexOptions
843-
+ ", setMapperIndexOptionsPruneToFalse: "
844-
+ setMapperIndexOptionsPruneToFalse
845-
+ ", queryOverridesPruningConfig: "
846-
+ queryOverridesPruningConfig
847-
+ ", queryOverridesPruneToBeFalse: "
848-
+ queryOverridesPruneToBeFalse;
849-
throw new AssertionError(message, e);
850-
}
851-
}
852-
853858
private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
854859
return usePreviousIndex
855860
? IndexVersionUtils.randomVersionBetween(
@@ -860,36 +865,10 @@ private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
860865
: IndexVersionUtils.randomVersionBetween(random(), SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT, IndexVersion.current());
861866
}
862867

863-
private MapperService getMapperServiceForTest(
864-
boolean usePreviousIndex,
865-
boolean useIndexOptionsDefaults,
866-
boolean explicitIndexOptionsDoNotPrune
867-
) throws IOException {
868-
// get the index version of the test to use
869-
// either a current version that supports index options, or a previous version that does not
870-
IndexVersion indexVersion = getIndexVersionForTest(usePreviousIndex);
871-
872-
// if it's using the old index, we always use the minimal mapping without index_options
873-
if (usePreviousIndex) {
874-
return createMapperService(indexVersion, fieldMapping(this::minimalMapping));
875-
}
876-
877-
// if we set explicitIndexOptionsDoNotPrune, the index_options (if present) will explicitly include "prune: false"
878-
if (explicitIndexOptionsDoNotPrune) {
879-
return createMapperService(indexVersion, fieldMapping(this::mappingWithIndexOptionsPruneFalse));
880-
}
881-
882-
// either return the default (minimal) mapping or one with an explicit pruning_config
883-
return useIndexOptionsDefaults
884-
? createMapperService(indexVersion, fieldMapping(this::minimalMapping))
885-
: createMapperService(indexVersion, fieldMapping(this::minimalMappingWithExplicitIndexOptions));
886-
}
887-
888-
private static List<WeightedToken> QUERY_VECTORS = List.of(
889-
new WeightedToken("pugs", 0.5f),
890-
new WeightedToken("cats", 0.4f),
891-
new WeightedToken("is", 0.1f)
892-
);
868+
private static final List<WeightedToken> QUERY_VECTORS = Stream.of(RARE_TOKENS, MEDIUM_TOKENS, COMMON_TOKENS)
869+
.flatMap(map -> map.entrySet().stream())
870+
.map(entry -> new WeightedToken(entry.getKey(), entry.getValue()))
871+
.collect(Collectors.toList());
893872

894873
/**
895874
* Handles float/double conversion when reading/writing with xcontent by converting all numbers to floats.

0 commit comments

Comments
 (0)