Skip to content

Commit 512e02a

Browse files
authored
SOLR-17815: Add parameter to regulate for ACORN-based filtering in vector search (#3680)
1 parent aeb9063 commit 512e02a

File tree

6 files changed

+945
-6
lines changed

6 files changed

+945
-6
lines changed

solr/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ New Features
8585

8686
* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti)
8787

88+
* SOLR-17815: Add parameter to regulate for ACORN-based filtering in vector search. (Anna Ruggero, Alessandro Benedetti)
89+
8890
Improvements
8991
---------------------
9092

solr/core/src/java/org/apache/solr/schema/DenseVectorField.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.apache.lucene.search.Query;
4444
import org.apache.lucene.search.SeededKnnVectorQuery;
4545
import org.apache.lucene.search.SortField;
46+
import org.apache.lucene.search.knn.KnnSearchStrategy;
4647
import org.apache.lucene.util.BytesRef;
4748
import org.apache.lucene.util.hnsw.HnswGraph;
4849
import org.apache.solr.common.SolrException;
@@ -379,17 +380,36 @@ public Query getKnnVectorQuery(
379380
int topK,
380381
Query filterQuery,
381382
Query seedQuery,
382-
EarlyTerminationParams earlyTermination) {
383+
EarlyTerminationParams earlyTermination,
384+
Integer filteredSearchThreshold) {
383385

384386
DenseVectorParser vectorBuilder =
385387
getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY);
386388

387389
final Query knnQuery =
388390
switch (vectorEncoding) {
389-
case FLOAT32 -> new KnnFloatVectorQuery(
390-
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery);
391-
case BYTE -> new KnnByteVectorQuery(
392-
fieldName, vectorBuilder.getByteVector(), topK, filterQuery);
391+
case FLOAT32 -> {
392+
if (filteredSearchThreshold != null) {
393+
KnnSearchStrategy knnSearchStrategy =
394+
new KnnSearchStrategy.Hnsw(filteredSearchThreshold);
395+
yield new KnnFloatVectorQuery(
396+
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery, knnSearchStrategy);
397+
} else {
398+
yield new KnnFloatVectorQuery(
399+
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery);
400+
}
401+
}
402+
case BYTE -> {
403+
if (filteredSearchThreshold != null) {
404+
KnnSearchStrategy knnSearchStrategy =
405+
new KnnSearchStrategy.Hnsw(filteredSearchThreshold);
406+
yield new KnnByteVectorQuery(
407+
fieldName, vectorBuilder.getByteVector(), topK, filterQuery, knnSearchStrategy);
408+
} else {
409+
yield new KnnByteVectorQuery(
410+
fieldName, vectorBuilder.getByteVector(), topK, filterQuery);
411+
}
412+
}
393413
};
394414

395415
final boolean seedEnabled = (seedQuery != null);

solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public class KnnQParser extends AbstractVectorQParserBase {
3232
protected static final String TOP_K = "topK";
3333
protected static final int DEFAULT_TOP_K = 10;
3434
protected static final String SEED_QUERY = "seedQuery";
35+
protected static final String FILTERED_SEARCH_THRESHOLD = "filteredSearchThreshold";
3536

3637
// parameters for PatienceKnnVectorQuery, a version of knn vector query that exits early when HNSW
3738
// queue saturates over a {@code #saturationThreshold} for more than {@code #patience} times.
@@ -107,13 +108,15 @@ public Query parse() throws SyntaxError {
107108
final DenseVectorField denseVectorType = getCheckedFieldType(schemaField);
108109
final String vectorToSearch = getVectorToSearch();
109110
final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K);
111+
final Integer filteredSearchThreshold = localParams.getInt(FILTERED_SEARCH_THRESHOLD);
110112

111113
return denseVectorType.getKnnVectorQuery(
112114
schemaField.getName(),
113115
vectorToSearch,
114116
topK,
115117
getFilterQuery(),
116118
getSeedQuery(),
117-
getEarlyTerminationParams());
119+
getEarlyTerminationParams(),
120+
filteredSearchThreshold);
118121
}
119122
}

solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
import java.util.Map;
2626
import org.apache.lucene.index.VectorEncoding;
2727
import org.apache.lucene.index.VectorSimilarityFunction;
28+
import org.apache.lucene.search.BooleanQuery;
29+
import org.apache.lucene.search.KnnByteVectorQuery;
30+
import org.apache.lucene.search.KnnFloatVectorQuery;
31+
import org.apache.lucene.search.PatienceKnnVectorQuery;
32+
import org.apache.lucene.search.Query;
33+
import org.apache.lucene.search.SeededKnnVectorQuery;
34+
import org.apache.lucene.search.knn.KnnSearchStrategy;
2835
import org.apache.solr.client.solrj.request.JavaBinUpdateRequestCodec;
2936
import org.apache.solr.client.solrj.request.UpdateRequest;
3037
import org.apache.solr.common.SolrException;
@@ -35,6 +42,7 @@
3542
import org.apache.solr.handler.loader.JavabinLoader;
3643
import org.apache.solr.request.SolrQueryRequest;
3744
import org.apache.solr.response.SolrQueryResponse;
45+
import org.apache.solr.search.neural.KnnQParser;
3846
import org.apache.solr.update.CommitUpdateCommand;
3947
import org.apache.solr.update.processor.UpdateRequestProcessor;
4048
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
@@ -838,4 +846,283 @@ public void testIndexingViaJavaBin() throws Exception {
838846
deleteCore();
839847
}
840848
}
849+
850+
@Test
851+
public void testFilteredSearchThreshold_floatNoThresholdInInput_shouldSetDefaultThreshold()
852+
throws Exception {
853+
try {
854+
Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD;
855+
856+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
857+
IndexSchema schema = h.getCore().getLatestSchema();
858+
SchemaField vectorField = schema.getField("vector");
859+
assertNotNull(vectorField);
860+
DenseVectorField type = (DenseVectorField) vectorField.getType();
861+
KnnFloatVectorQuery vectorQuery =
862+
(KnnFloatVectorQuery)
863+
type.getKnnVectorQuery("vector", "[2, 1, 3, 4]", 3, null, null, null, null);
864+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
865+
Integer threshold = strategy.filteredSearchThreshold();
866+
867+
assertEquals(expectedThreshold, threshold);
868+
} finally {
869+
deleteCore();
870+
}
871+
}
872+
873+
@Test
874+
public void testFilteredSearchThreshold_floatThresholdInInput_shouldSetCustomThreshold()
875+
throws Exception {
876+
try {
877+
Integer expectedThreshold = 30;
878+
879+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
880+
IndexSchema schema = h.getCore().getLatestSchema();
881+
SchemaField vectorField = schema.getField("vector");
882+
assertNotNull(vectorField);
883+
DenseVectorField type = (DenseVectorField) vectorField.getType();
884+
KnnFloatVectorQuery vectorQuery =
885+
(KnnFloatVectorQuery)
886+
type.getKnnVectorQuery(
887+
"vector", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold);
888+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
889+
Integer threshold = strategy.filteredSearchThreshold();
890+
891+
assertEquals(expectedThreshold, threshold);
892+
} finally {
893+
deleteCore();
894+
}
895+
}
896+
897+
@Test
898+
public void testFilteredSearchThreshold_seededFloatThresholdInInput_shouldSetCustomThreshold()
899+
throws Exception {
900+
try {
901+
Query seedQuery = new BooleanQuery.Builder().build();
902+
Integer expectedThreshold = 30;
903+
904+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
905+
IndexSchema schema = h.getCore().getLatestSchema();
906+
SchemaField vectorField = schema.getField("vector");
907+
assertNotNull(vectorField);
908+
DenseVectorField type = (DenseVectorField) vectorField.getType();
909+
SeededKnnVectorQuery vectorQuery =
910+
(SeededKnnVectorQuery)
911+
type.getKnnVectorQuery(
912+
"vector", "[2, 1, 3, 4]", 3, null, seedQuery, null, expectedThreshold);
913+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
914+
Integer threshold = strategy.filteredSearchThreshold();
915+
916+
assertEquals(expectedThreshold, threshold);
917+
} finally {
918+
deleteCore();
919+
}
920+
}
921+
922+
@Test
923+
public void
924+
testFilteredSearchThreshold_earlyTerminationFloatThresholdInInput_shouldSetCustomThreshold()
925+
throws Exception {
926+
try {
927+
KnnQParser.EarlyTerminationParams earlyTermination =
928+
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
929+
Integer expectedThreshold = 30;
930+
931+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
932+
IndexSchema schema = h.getCore().getLatestSchema();
933+
SchemaField vectorField = schema.getField("vector");
934+
assertNotNull(vectorField);
935+
DenseVectorField type = (DenseVectorField) vectorField.getType();
936+
PatienceKnnVectorQuery vectorQuery =
937+
(PatienceKnnVectorQuery)
938+
type.getKnnVectorQuery(
939+
"vector", "[2, 1, 3, 4]", 3, null, null, earlyTermination, expectedThreshold);
940+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
941+
Integer threshold = strategy.filteredSearchThreshold();
942+
943+
assertEquals(expectedThreshold, threshold);
944+
} finally {
945+
deleteCore();
946+
}
947+
}
948+
949+
@Test
950+
public void
951+
testFilteredSearchThreshold_seededAndEarlyTerminationFloatThresholdInInput_shouldSetCustomThreshold()
952+
throws Exception {
953+
try {
954+
Query seedQuery = new BooleanQuery.Builder().build();
955+
KnnQParser.EarlyTerminationParams earlyTermination =
956+
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
957+
Integer expectedThreshold = 30;
958+
959+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
960+
IndexSchema schema = h.getCore().getLatestSchema();
961+
SchemaField vectorField = schema.getField("vector");
962+
assertNotNull(vectorField);
963+
DenseVectorField type = (DenseVectorField) vectorField.getType();
964+
PatienceKnnVectorQuery vectorQuery =
965+
(PatienceKnnVectorQuery)
966+
type.getKnnVectorQuery(
967+
"vector",
968+
"[2, 1, 3, 4]",
969+
3,
970+
null,
971+
seedQuery,
972+
earlyTermination,
973+
expectedThreshold);
974+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
975+
Integer threshold = strategy.filteredSearchThreshold();
976+
977+
assertEquals(expectedThreshold, threshold);
978+
} finally {
979+
deleteCore();
980+
}
981+
}
982+
983+
@Test
984+
public void testFilteredSearchThreshold_byteNoThresholdInInput_shouldSetDefaultThreshold()
985+
throws Exception {
986+
try {
987+
Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD;
988+
989+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
990+
IndexSchema schema = h.getCore().getLatestSchema();
991+
SchemaField vectorField = schema.getField("vector_byte_encoding");
992+
assertNotNull(vectorField);
993+
DenseVectorField type = (DenseVectorField) vectorField.getType();
994+
KnnByteVectorQuery vectorQuery =
995+
(KnnByteVectorQuery)
996+
type.getKnnVectorQuery(
997+
"vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, null);
998+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
999+
Integer threshold = strategy.filteredSearchThreshold();
1000+
1001+
assertEquals(expectedThreshold, threshold);
1002+
} finally {
1003+
deleteCore();
1004+
}
1005+
}
1006+
1007+
@Test
1008+
public void testFilteredSearchThreshold_byteThresholdInInput_shouldSetCustomThreshold()
1009+
throws Exception {
1010+
try {
1011+
Integer expectedThreshold = 30;
1012+
1013+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
1014+
IndexSchema schema = h.getCore().getLatestSchema();
1015+
SchemaField vectorField = schema.getField("vector_byte_encoding");
1016+
assertNotNull(vectorField);
1017+
DenseVectorField type = (DenseVectorField) vectorField.getType();
1018+
KnnByteVectorQuery vectorQuery =
1019+
(KnnByteVectorQuery)
1020+
type.getKnnVectorQuery(
1021+
"vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold);
1022+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
1023+
Integer threshold = strategy.filteredSearchThreshold();
1024+
1025+
assertEquals(expectedThreshold, threshold);
1026+
} finally {
1027+
deleteCore();
1028+
}
1029+
}
1030+
1031+
@Test
1032+
public void testFilteredSearchThreshold_seededByteThresholdInInput_shouldSetCustomThreshold()
1033+
throws Exception {
1034+
try {
1035+
Query seedQuery = new BooleanQuery.Builder().build();
1036+
Integer expectedThreshold = 30;
1037+
1038+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
1039+
IndexSchema schema = h.getCore().getLatestSchema();
1040+
SchemaField vectorField = schema.getField("vector_byte_encoding");
1041+
assertNotNull(vectorField);
1042+
DenseVectorField type = (DenseVectorField) vectorField.getType();
1043+
SeededKnnVectorQuery vectorQuery =
1044+
(SeededKnnVectorQuery)
1045+
type.getKnnVectorQuery(
1046+
"vector_byte_encoding",
1047+
"[2, 1, 3, 4]",
1048+
3,
1049+
null,
1050+
seedQuery,
1051+
null,
1052+
expectedThreshold);
1053+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
1054+
Integer threshold = strategy.filteredSearchThreshold();
1055+
1056+
assertEquals(expectedThreshold, threshold);
1057+
} finally {
1058+
deleteCore();
1059+
}
1060+
}
1061+
1062+
@Test
1063+
public void
1064+
testFilteredSearchThreshold_earlyTerminationByteThresholdInInput_shouldSetCustomThreshold()
1065+
throws Exception {
1066+
try {
1067+
KnnQParser.EarlyTerminationParams earlyTermination =
1068+
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
1069+
Integer expectedThreshold = 30;
1070+
1071+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
1072+
IndexSchema schema = h.getCore().getLatestSchema();
1073+
SchemaField vectorField = schema.getField("vector_byte_encoding");
1074+
assertNotNull(vectorField);
1075+
DenseVectorField type = (DenseVectorField) vectorField.getType();
1076+
PatienceKnnVectorQuery vectorQuery =
1077+
(PatienceKnnVectorQuery)
1078+
type.getKnnVectorQuery(
1079+
"vector_byte_encoding",
1080+
"[2, 1, 3, 4]",
1081+
3,
1082+
null,
1083+
null,
1084+
earlyTermination,
1085+
expectedThreshold);
1086+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
1087+
Integer threshold = strategy.filteredSearchThreshold();
1088+
1089+
assertEquals(expectedThreshold, threshold);
1090+
} finally {
1091+
deleteCore();
1092+
}
1093+
}
1094+
1095+
@Test
1096+
public void
1097+
testFilteredSearchThreshold_seededAndEarlyTerminationByteThresholdInInput_shouldSetCustomThreshold()
1098+
throws Exception {
1099+
try {
1100+
Query seedQuery = new BooleanQuery.Builder().build();
1101+
KnnQParser.EarlyTerminationParams earlyTermination =
1102+
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
1103+
Integer expectedThreshold = 30;
1104+
1105+
initCore("solrconfig-basic.xml", "schema-densevector.xml");
1106+
IndexSchema schema = h.getCore().getLatestSchema();
1107+
SchemaField vectorField = schema.getField("vector_byte_encoding");
1108+
assertNotNull(vectorField);
1109+
DenseVectorField type = (DenseVectorField) vectorField.getType();
1110+
PatienceKnnVectorQuery vectorQuery =
1111+
(PatienceKnnVectorQuery)
1112+
type.getKnnVectorQuery(
1113+
"vector_byte_encoding",
1114+
"[2, 1, 3, 4]",
1115+
3,
1116+
null,
1117+
seedQuery,
1118+
earlyTermination,
1119+
expectedThreshold);
1120+
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
1121+
Integer threshold = strategy.filteredSearchThreshold();
1122+
1123+
assertEquals(expectedThreshold, threshold);
1124+
} finally {
1125+
deleteCore();
1126+
}
1127+
}
8411128
}

0 commit comments

Comments
 (0)