Skip to content

Commit 9b957e8

Browse files
shubhamvishu and benwtrent
authored and committed
Bypass HNSW graph building for tiny segments (#14963)
This change avoids creating a HNSW graph if the segment is small (here we have taken the threshold for the number of vectors as `10000` based on the conversation [here](#13447 (comment)) for now). Some of the points I'm not sure how we would want to go about: - All the tests pass currently since the option to enable the optimization is `false` by default, but setting it to `true` reveals some failing unit tests which inherently assume that the HNSW graph is created and KNN search is triggered (do we have some idea of how to bypass those in some good clean way?) - I understand we might want to always keep this optimization on (also a less invasive change), but for now in this PR, I made it configurable and enabled it on the KNN format - just to be cautious (wasn't sure if it would not affect back-compat in some unknown way), but happy to make it the default behaviour
1 parent 3eea6aa commit 9b957e8

17 files changed

+298
-42
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ Optimizations
7575
* GITHUB#15343: Ensure that `AcceptDocs#cost()` only ever calls `BitSets#cardinality()`
7676
once per instance to avoid redundant computation. (Ben Trent)
7777

78+
* GITHUB#14963: Bypass HNSW graph building for tiny segments. (Shubham Chaudhary, Ben Trent)
79+
7880
Bug Fixes
7981
---------------------
8082
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWV0HnswScalarQuantizationVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
4747
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
4848
flatVectorsFormat.fieldsWriter(state),
4949
1,
50-
null);
50+
null,
51+
0);
5152
}
5253

5354
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWV1HnswScalarQuantizationVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
5757
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
5858
flatVectorsFormat.fieldsWriter(state),
5959
1,
60-
null);
60+
null,
61+
0);
6162
}
6263
}

lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
128128
beamWidth,
129129
flatVectorsFormat.fieldsWriter(state),
130130
numMergeWorkers,
131-
mergeExec);
131+
mergeExec,
132+
0);
132133
}
133134

134135
@Override

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
2020
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
2121
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
22+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.HNSW_GRAPH_THRESHOLD;
2223
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
2324
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
2425

@@ -60,6 +61,20 @@ public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat
6061
/** The format for storing, reading, merging vectors on disk */
6162
private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;
6263

64+
/**
65+
* The threshold to use to bypass HNSW graph building for tiny segments in terms of k for a graph
66+
* i.e. number of docs to match the query (default is {@link
67+
* Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD}).
68+
*
69+
* <ul>
70+
* <li>0 indicates that the graph is always built.
71+
* <li>greater than 0 indicates that the graph needs a certain number of nodes before it starts
72+
* building. See {@link Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD} for details.
73+
* <li>Negative values aren't allowed.
74+
* </ul>
75+
*/
76+
private final int tinySegmentsThreshold;
77+
6378
private final int numMergeWorkers;
6479
private final TaskExecutor mergeExec;
6580

@@ -70,7 +85,8 @@ public Lucene104HnswScalarQuantizedVectorsFormat() {
7085
DEFAULT_MAX_CONN,
7186
DEFAULT_BEAM_WIDTH,
7287
DEFAULT_NUM_MERGE_WORKER,
73-
null);
88+
null,
89+
HNSW_GRAPH_THRESHOLD);
7490
}
7591

7692
/**
@@ -80,7 +96,13 @@ public Lucene104HnswScalarQuantizedVectorsFormat() {
8096
* @param beamWidth the size of the queue maintained during graph construction.
8197
*/
8298
public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
83-
this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
99+
this(
100+
ScalarEncoding.UNSIGNED_BYTE,
101+
maxConn,
102+
beamWidth,
103+
DEFAULT_NUM_MERGE_WORKER,
104+
null,
105+
HNSW_GRAPH_THRESHOLD);
84106
}
85107

86108
/**
@@ -99,6 +121,26 @@ public Lucene104HnswScalarQuantizedVectorsFormat(
99121
int beamWidth,
100122
int numMergeWorkers,
101123
ExecutorService mergeExec) {
124+
this(encoding, maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD);
125+
}
126+
127+
/**
128+
* Constructs a format using the given graph construction parameters and scalar quantization.
129+
*
130+
* @param maxConn the maximum number of connections to a node in the HNSW graph
131+
* @param beamWidth the size of the queue maintained during graph construction.
132+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
133+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
134+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
135+
* generated by this format to do the merge
136+
*/
137+
public Lucene104HnswScalarQuantizedVectorsFormat(
138+
ScalarEncoding encoding,
139+
int maxConn,
140+
int beamWidth,
141+
int numMergeWorkers,
142+
ExecutorService mergeExec,
143+
int tinySegmentsThreshold) {
102144
super(NAME);
103145
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
104146
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
@@ -117,6 +159,7 @@ public Lucene104HnswScalarQuantizedVectorsFormat(
117159
}
118160
this.maxConn = maxConn;
119161
this.beamWidth = beamWidth;
162+
this.tinySegmentsThreshold = tinySegmentsThreshold;
120163
if (numMergeWorkers == 1 && mergeExec != null) {
121164
throw new IllegalArgumentException(
122165
"No executor service is needed as we'll use single thread to merge");
@@ -137,7 +180,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
137180
beamWidth,
138181
flatVectorsFormat.fieldsWriter(state),
139182
numMergeWorkers,
140-
mergeExec);
183+
mergeExec,
184+
tinySegmentsThreshold);
141185
}
142186

143187
@Override
@@ -156,6 +200,8 @@ public String toString() {
156200
+ maxConn
157201
+ ", beamWidth="
158202
+ beamWidth
203+
+ ", tinySegmentsThreshold="
204+
+ tinySegmentsThreshold
159205
+ ", flatVectorFormat="
160206
+ flatVectorsFormat
161207
+ ")";

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
116116
/** Default to use single thread merge */
117117
public static final int DEFAULT_NUM_MERGE_WORKER = 1;
118118

119+
/**
120+
* Minimum estimated search effort (in terms of expected visited nodes) required before building
121+
* an HNSW graph for a segment.
122+
*
123+
* <p>This threshold is compared against the value produced by {@link
124+
* org.apache.lucene.util.hnsw.HnswGraphSearcher#expectedVisitedNodes(int, int)}, which estimates
125+
* how many nodes would be visited during a vector search based on the current graph size and
126+
* {@code k} (neighbours to find).
127+
*
128+
* <p>If the estimated number of visited nodes falls below this threshold, HNSW graph construction
129+
* is skipped for that segment - typically for small flushes or low document count segments -
130+
* since the overhead of building the graph would outweigh its search benefits.
131+
*
132+
* <p>Default: {@code 100}
133+
*/
134+
public static final int HNSW_GRAPH_THRESHOLD = 100;
135+
119136
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
120137

121138
/**
@@ -138,11 +155,30 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
138155
private final int numMergeWorkers;
139156
private final TaskExecutor mergeExec;
140157

158+
/**
159+
* The threshold to use to bypass HNSW graph building for tiny segments in terms of k for a graph
160+
* i.e. number of docs to match the query (default is {@link
161+
* Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD}).
162+
*
163+
* <ul>
164+
* <li>0 indicates that the graph is always built.
165+
* <li>greater than 0 indicates that the graph needs a certain number of nodes before it starts building.
166+
* <li>Negative values aren't allowed.
167+
* </ul>
168+
*/
169+
private final int tinySegmentsThreshold;
170+
141171
private final int writeVersion;
142172

143173
/** Constructs a format using default graph construction parameters */
144174
public Lucene99HnswVectorsFormat() {
145-
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
175+
this(
176+
DEFAULT_MAX_CONN,
177+
DEFAULT_BEAM_WIDTH,
178+
DEFAULT_NUM_MERGE_WORKER,
179+
null,
180+
HNSW_GRAPH_THRESHOLD,
181+
VERSION_CURRENT);
146182
}
147183

148184
/**
@@ -152,7 +188,20 @@ public Lucene99HnswVectorsFormat() {
152188
* @param beamWidth the size of the queue maintained during graph construction.
153189
*/
154190
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
155-
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
191+
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD, VERSION_CURRENT);
192+
}
193+
194+
/**
195+
* Constructs a format using the given graph construction parameters.
196+
*
197+
* @param maxConn the maximum number of connections to a node in the HNSW graph
198+
* @param beamWidth the size of the queue maintained during graph construction.
199+
* @param tinySegmentsThreshold the expected number of vector operations to return k nearest
200+
* neighbors of the current graph size
201+
*/
202+
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth, int tinySegmentsThreshold) {
203+
this(
204+
maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, tinySegmentsThreshold, VERSION_CURRENT);
156205
}
157206

158207
/**
@@ -168,7 +217,7 @@ public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
168217
*/
169218
public Lucene99HnswVectorsFormat(
170219
int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
171-
this(maxConn, beamWidth, numMergeWorkers, mergeExec, VERSION_CURRENT);
220+
this(maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD, VERSION_CURRENT);
172221
}
173222

174223
/**
@@ -182,13 +231,38 @@ public Lucene99HnswVectorsFormat(
182231
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
183232
* generated by this format to do the merge. If null, the configured {@link
184233
* MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used.
234+
* @param tinySegmentsThreshold the expected number of vector operations to return k nearest
235+
* neighbors of the current graph size
236+
*/
237+
public Lucene99HnswVectorsFormat(
238+
int maxConn,
239+
int beamWidth,
240+
int numMergeWorkers,
241+
ExecutorService mergeExec,
242+
int tinySegmentsThreshold) {
243+
this(maxConn, beamWidth, numMergeWorkers, mergeExec, tinySegmentsThreshold, VERSION_CURRENT);
244+
}
245+
246+
/**
247+
* Constructs a format using the given graph construction parameters.
248+
*
249+
* @param maxConn the maximum number of connections to a node in the HNSW graph
250+
* @param beamWidth the size of the queue maintained during graph construction.
251+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
252+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
253+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
254+
* generated by this format to do the merge. If null, the configured {@link
255+
* MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used.
256+
* @param tinySegmentsThreshold the expected number of vector operations to return k nearest
257+
* neighbors of the current graph size
185258
* @param writeVersion the version used for the writer to encode docID's (VarInt=0, GroupVarInt=1)
186259
*/
187260
Lucene99HnswVectorsFormat(
188261
int maxConn,
189262
int beamWidth,
190263
int numMergeWorkers,
191264
ExecutorService mergeExec,
265+
int tinySegmentsThreshold,
192266
int writeVersion) {
193267
super("Lucene99HnswVectorsFormat");
194268
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
@@ -207,6 +281,7 @@ public Lucene99HnswVectorsFormat(
207281
}
208282
this.maxConn = maxConn;
209283
this.beamWidth = beamWidth;
284+
this.tinySegmentsThreshold = tinySegmentsThreshold;
210285
this.writeVersion = writeVersion;
211286
if (numMergeWorkers == 1 && mergeExec != null) {
212287
throw new IllegalArgumentException(
@@ -229,6 +304,7 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
229304
flatVectorsFormat.fieldsWriter(state),
230305
numMergeWorkers,
231306
mergeExec,
307+
tinySegmentsThreshold,
232308
writeVersion);
233309
}
234310

@@ -248,6 +324,8 @@ public String toString() {
248324
+ maxConn
249325
+ ", beamWidth="
250326
+ beamWidth
327+
+ ", tinySegmentsThreshold="
328+
+ tinySegmentsThreshold
251329
+ ", flatVectorFormat="
252330
+ flatVectorsFormat
253331
+ ")";

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,18 +338,18 @@ private void search(
338338
final RandomVectorScorer scorer = scorerSupplier.get();
339339
final KnnCollector collector =
340340
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
341-
HnswGraph graph = getGraph(fieldEntry);
342341
// Take into account if quantized? E.g. some scorer cost?
343342
// Use approximate cardinality as this is good enough, but ensure we don't exceed the graph
344343
// size as that is illogical
344+
HnswGraph graph = getGraph(fieldEntry);
345345
int filteredDocCount = Math.min(acceptDocs.cost(), graph.size());
346346
Bits accepted = acceptDocs.bits();
347347
final Bits acceptedOrds = scorer.getAcceptOrds(accepted);
348348
int numVectors = scorer.maxOrd();
349349
boolean doHnsw = knnCollector.k() < numVectors;
350350
// The approximate number of vectors that would be visited if we did not filter
351351
int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size());
352-
if (unfilteredVisit >= filteredDocCount) {
352+
if (unfilteredVisit >= filteredDocCount || graph.size() == 0) {
353353
doHnsw = false;
354354
}
355355
if (doHnsw) {
@@ -405,6 +405,9 @@ public HnswGraph getGraph(String field) throws IOException {
405405
}
406406

407407
private HnswGraph getGraph(FieldEntry entry) throws IOException {
408+
if (entry.vectorIndexLength == 0) {
409+
return HnswGraph.EMPTY;
410+
}
408411
return new OffHeapHnswGraph(entry, vectorIndex);
409412
}
410413

0 commit comments

Comments
 (0)