Commit d37bad0: address PR comments and minor refactoring

1 parent a79f56f

1 file changed: +61, -55 lines
server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java

@@ -17,6 +17,7 @@
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
@@ -39,6 +40,7 @@
 import org.elasticsearch.search.lookup.Source;
 import org.elasticsearch.search.vectors.SparseVectorQueryWrapper;
 import org.elasticsearch.test.index.IndexVersionUtils;
+import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParseException;
@@ -48,6 +50,7 @@
 import org.junit.AssumptionViolatedException;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.EnumSet;
@@ -56,6 +59,8 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.index.IndexVersions.SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT;
 import static org.elasticsearch.index.IndexVersions.UPGRADE_TO_LUCENE_10_0_0;
@@ -73,25 +78,27 @@ public class SparseVectorFieldMapperTests extends SyntheticVectorsMapperTestCase
     public static final float STRICT_TOKENS_WEIGHT_THRESHOLD = 0.5f;
     public static final float STRICT_TOKENS_FREQ_RATIO_THRESHOLD = 1;
 
-    @Override
-    protected Object getSampleValueForDocument() {
-        // randomMap(1, 5, () -> Tuple.tuple(randomAlphaOfLengthBetween(5, 10), Float.valueOf(randomIntBetween(1, 127))))
-        Map<String, Float> map = new TreeMap<>();
-
-        // High weight tokens - low freq (should survive strict pruning)
-        map.put("rare1", (float) randomIntBetween(1, 127));
-        map.put("rare2", (float) randomIntBetween(1, 127));
+    private static final Map<String, Float> COMMON_TOKENS = Map.of(
+        "common1_drop_default", 0.1f,
+        "common2_drop_default", 0.1f,
+        "common3_drop_default", 0.1f
+    );
 
-        // Medium weight - medium freq (half of them should survive strict pruning)
-        map.put("medium_freq", (float) randomIntBetween(1, 127));
-        map.put("medium_freq2", (float) randomIntBetween(1, 127));
+    private static final Map<String, Float> MEDIUM_TOKENS = Map.of(
+        "medium1_keep_strict", 0.5f,
+        "medium2_keep_default", 0.25f
+    );
 
-        // Low weight tokens - high freq (pruned under default or strict pruning)
-        map.put("common1", (float) randomIntBetween(1, 127));
-        map.put("common2", (float) randomIntBetween(1, 127));
-        map.put("common3", (float) randomIntBetween(1, 127));
+    private static final Map<String, Float> RARE_TOKENS = Map.of(
+        "rare1_keep_strict", 0.9f,
+        "rare2_keep_strict", 0.85f
+    );
 
-        return map;
+    @Override
+    protected Object getSampleValueForDocument() {
+        return new TreeMap<>(
+            randomMap(1, 5, () -> Tuple.tuple(randomAlphaOfLengthBetween(5, 10), Float.valueOf(randomIntBetween(1, 127))))
+        );
     }
 
     @Override
@@ -722,30 +729,17 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
         try (Directory directory = newDirectory()) {
             RandomIndexWriter iw = new RandomIndexWriter(random(), directory);
 
-            Map<String, Float> commonTokens = new TreeMap<>();
-            commonTokens.put("common1", 0.1f);
-            commonTokens.put("common2", 0.1f);
-            commonTokens.put("common3", 0.1f);
-
-            Map<String, Float> mediumTokens = new TreeMap<>();
-            mediumTokens.put("medium1", 0.5f);
-            mediumTokens.put("medium2", 0.25f);
-
-            Map<String, Float> rareTokens = new TreeMap<>();
-            rareTokens.put("rare1", 0.9f);
-            rareTokens.put("rare2", 0.85f);
-
             int commonDocs = 20;
             for (int i = 0; i < commonDocs; i++) {
-                iw.addDocument(mapper.parse(source(b -> b.field("field", commonTokens))).rootDoc());
+                iw.addDocument(mapper.parse(source(b -> b.field("field", COMMON_TOKENS))).rootDoc());
             }
 
             int mediumDocs = 5;
             for (int i = 0; i < mediumDocs; i++) {
-                iw.addDocument(mapper.parse(source(b -> b.field("field", mediumTokens))).rootDoc());
+                iw.addDocument(mapper.parse(source(b -> b.field("field", MEDIUM_TOKENS))).rootDoc());
             }
 
-            iw.addDocument(mapper.parse(source(b -> b.field("field", rareTokens))).rootDoc());
+            iw.addDocument(mapper.parse(source(b -> b.field("field", RARE_TOKENS))).rootDoc());
 
             // This will lower the averageTokenFreqRatio so that common tokens get pruned with default settings
             Map<String, Float> uniqueDoc = new TreeMap<>();
@@ -762,13 +756,13 @@ private void withSearchExecutionContext(MapperService mapperService, CheckedCons
         }
     }
 
-    public void testTypeQueryFinalizationPruningScenarios() throws Exception {
+    public void testPruningScenarios() throws Exception {
         for (int i = 0; i < 60; i++) {
-            runTestTypeQueryFinalization(randomFrom(IndexPruningScenario.values()), randomFrom(QueryPruningScenario.values()));
+            assertPruningScenario(randomFrom(IndexPruningScenario.values()), randomFrom(QueryPruningScenario.values()));
         }
     }
 
-    public void testTypeQueryFinalizationDefaultsPreviousVersion() throws Exception {
+    public void testPruningDefaultsPreIndexOptions() throws Exception {
         IndexVersion version = IndexVersionUtils.randomVersionBetween(
             random(),
             UPGRADE_TO_LUCENE_10_0_0,
@@ -791,7 +785,8 @@ public void testTypeQueryFinalizationDefaultsPreviousVersion() throws Exception
                 queryPruneConfig.v2()
             );
             // query should _not_ be pruned by default on older index versions
-            assertQueryWasPruned(finalizedQuery, PruningScenario.NO_PRUNING);
+            List<Query> expectedQueryClauses = getExpectedQueryClauses(ft, PruningScenario.NO_PRUNING, context);
+            assertQueryContains(expectedQueryClauses, finalizedQuery);
         });
 
     }
@@ -805,12 +800,14 @@ private XContentBuilder getIndexMapping(IndexPruningScenario pruningScenario) th
         };
     }
 
-    private void assertQueryWasPruned(Query query, PruningScenario pruningScenario) {
-        switch (pruningScenario) {
-            case NO_PRUNING -> assertQueryHasClauseCount(query, QUERY_VECTORS.size());
-            case DEFAULT_PRUNING -> assertQueryHasClauseCount(query, QUERY_VECTORS.size() - 3); // 3 common tokens pruned
-            case STRICT_PRUNING -> assertQueryHasClauseCount(query, QUERY_VECTORS.size() - 4); // 3 common and 1 medium tokens pruned
-        }
+    private void assertQueryContains(List<Query> expectedClauses, Query query) {
+        SparseVectorQueryWrapper queryWrapper = (SparseVectorQueryWrapper) query;
+        var termsQuery = queryWrapper.getTermsQuery();
+        assertNotNull(termsQuery);
+        var booleanQuery = (BooleanQuery) termsQuery;
+
+        Collection<Query> shouldClauses = booleanQuery.getClauses(BooleanClause.Occur.SHOULD);
+        assertThat(shouldClauses, Matchers.containsInAnyOrder(expectedClauses.toArray()));
     }
 
     private void assertQueryHasClauseCount(Query query, int clauseCount) {
@@ -880,7 +877,22 @@ private Tuple<Boolean, TokenPruningConfig> getQueryPruneConfig(QueryPruningScena
         };
     }
 
-    private void runTestTypeQueryFinalization(IndexPruningScenario indexPruningScenario, QueryPruningScenario queryPruningScenario)
+    private List<Query> getExpectedQueryClauses(SparseVectorFieldMapper.SparseVectorFieldType ft, PruningScenario pruningScenario, SearchExecutionContext searchExecutionContext) {
+        List<WeightedToken> tokens = switch (pruningScenario) {
+            case NO_PRUNING -> QUERY_VECTORS;
+            case DEFAULT_PRUNING -> QUERY_VECTORS.stream()
+                .filter(t -> t.token().startsWith("rare") || t.token().startsWith("medium")).toList();
+            case STRICT_PRUNING -> QUERY_VECTORS.stream()
+                .filter(t -> t.token().endsWith("keep_strict")).toList();
+        };
+
+        return tokens.stream().map(t -> {
+            Query termQuery = ft.termQuery(t.token(), searchExecutionContext);
+            return new BoostQuery(termQuery, t.weight());
+        }).collect(Collectors.toUnmodifiableList());
+    }
+
+    private void assertPruningScenario(IndexPruningScenario indexPruningScenario, QueryPruningScenario queryPruningScenario)
         throws IOException {
         logger.debug("Running test with indexPruningScenario: {}, queryPruningScenario: {}", indexPruningScenario, queryPruningScenario);
         IndexVersion indexVersion = IndexVersionUtils.randomVersionBetween(
@@ -896,14 +908,15 @@ private void runTestTypeQueryFinalization(IndexPruningScenario indexPruningScena
             SparseVectorFieldMapper.SparseVectorFieldType ft = (SparseVectorFieldMapper.SparseVectorFieldType) mapperService.fieldType(
                 "field"
             );
+            List<Query> expectedQueryClauses = getExpectedQueryClauses(ft, effectivePruningScenario, context);
             Query finalizedQuery = ft.finalizeSparseVectorQuery(
                 context,
                 "field",
                 QUERY_VECTORS,
                 queryPruneConfig.v1(),
                 queryPruneConfig.v2()
             );
-            assertQueryWasPruned(finalizedQuery, effectivePruningScenario);
+            assertQueryContains(expectedQueryClauses, finalizedQuery);
         });
     }
 
@@ -917,17 +930,10 @@ private IndexVersion getIndexVersionForTest(boolean usePreviousIndex) {
             : IndexVersionUtils.randomVersionBetween(random(), SPARSE_VECTOR_PRUNING_INDEX_OPTIONS_SUPPORT, IndexVersion.current());
     }
 
-    private static List<WeightedToken> QUERY_VECTORS = List.of(
-        new WeightedToken("rare1", 0.9f),
-        new WeightedToken("rare2", 0.85f),
-
-        new WeightedToken("medium1", 0.5f), // this will survive strict pruning, due to higher weight
-        new WeightedToken("medium2", 0.25f),
-
-        new WeightedToken("common1", 0.2f),
-        new WeightedToken("common2", 0.15f),
-        new WeightedToken("common3", 0.1f)
-    );
+    private static final List<WeightedToken> QUERY_VECTORS = Stream.of(RARE_TOKENS, MEDIUM_TOKENS, COMMON_TOKENS)
+        .flatMap(map -> map.entrySet().stream())
+        .map(entry -> new WeightedToken(entry.getKey(), entry.getValue()))
+        .collect(Collectors.toList());
 
     /**
      * Handles float/double conversion when reading/writing with xcontent by converting all numbers to floats.
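Note on the refactoring above: instead of asserting hard-coded clause counts per scenario, the test now derives the expected clauses from the token-name suffixes (_keep_strict, _keep_default, _drop_default). A minimal, self-contained sketch of that suffix-driven expectation logic follows; the WeightedToken record, the PruningScenario enum, and the class name here are local stand-ins made up for illustration, not the Elasticsearch or Lucene types used by the test.

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

public class PruningExpectationSketch {

    // Local stand-in for the WeightedToken type used by the test.
    record WeightedToken(String token, float weight) {}

    enum PruningScenario { NO_PRUNING, DEFAULT_PRUNING, STRICT_PRUNING }

    // Same fixture data as the test: the token suffix encodes the expected pruning outcome.
    static final Map<String, Float> COMMON_TOKENS = Map.of(
        "common1_drop_default", 0.1f, "common2_drop_default", 0.1f, "common3_drop_default", 0.1f
    );
    static final Map<String, Float> MEDIUM_TOKENS = Map.of("medium1_keep_strict", 0.5f, "medium2_keep_default", 0.25f);
    static final Map<String, Float> RARE_TOKENS = Map.of("rare1_keep_strict", 0.9f, "rare2_keep_strict", 0.85f);

    static final List<WeightedToken> QUERY_VECTORS = Stream.of(RARE_TOKENS, MEDIUM_TOKENS, COMMON_TOKENS)
        .flatMap(map -> map.entrySet().stream())
        .map(entry -> new WeightedToken(entry.getKey(), entry.getValue()))
        .toList();

    // Mirrors the shape of getExpectedQueryClauses: select the tokens expected to survive each scenario by suffix.
    static List<WeightedToken> expectedSurvivors(PruningScenario scenario) {
        return switch (scenario) {
            case NO_PRUNING -> QUERY_VECTORS;
            case DEFAULT_PRUNING -> QUERY_VECTORS.stream()
                .filter(t -> t.token().startsWith("rare") || t.token().startsWith("medium"))
                .toList();
            case STRICT_PRUNING -> QUERY_VECTORS.stream()
                .filter(t -> t.token().endsWith("keep_strict"))
                .toList();
        };
    }

    public static void main(String[] args) {
        // NO_PRUNING keeps all 7 tokens, DEFAULT_PRUNING drops the 3 *_drop_default tokens,
        // STRICT_PRUNING additionally drops medium2_keep_default.
        for (PruningScenario scenario : PruningScenario.values()) {
            System.out.println(scenario + " -> " + expectedSurvivors(scenario).stream().map(WeightedToken::token).toList());
        }
    }
}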
