Skip to content

Commit 4a7c9c2

Browse files
authored
Bypass HNSW graph building for tiny segments (#14963)
This change avoids creating a HNSW graph if the segment is small (here we have taken the threshold for the number of vectors as `10000` based on the conversation [here](#13447 (comment)) for now). Some of the points I'm not sure how we would want to go about: - All the tests pass currently since the option to enable the optimization is `false` by default, but setting it to `true` reveals some failing unit tests which inherently assume that the HNSW graph is created and KNN search is triggered (do we have some idea of how to bypass those in some good clean way?) - I understand we might want to always keep this optimization on (also a less invasive change), but for now in this PR, I made it configurable and enabled it on the KNN format - just to be cautious (wasn't sure if it would affect back-compat in some unknown way), but happy to make it the default behaviour
1 parent 78dd964 commit 4a7c9c2

18 files changed

+305
-46
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ Optimizations
203203
* GITHUB#15343: Ensure that `AcceptDocs#cost()` only ever calls `BitSets#cardinality()`
204204
once per instance to avoid redundant computation. (Ben Trent)
205205

206+
* GITHUB#14963: Bypass HNSW graph building for tiny segments. (Shubham Chaudhary, Ben Trent)
207+
206208
Bug Fixes
207209
---------------------
208210
* GITHUB#14161: PointInSetQuery's constructor now throws IllegalArgumentException

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWV0HnswScalarQuantizationVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
4747
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
4848
flatVectorsFormat.fieldsWriter(state),
4949
1,
50-
null);
50+
null,
51+
0);
5152
}
5253

5354
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {

lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWV1HnswScalarQuantizationVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
5757
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
5858
flatVectorsFormat.fieldsWriter(state),
5959
1,
60-
null);
60+
null,
61+
0);
6162
}
6263
}

lucene/core/src/java/org/apache/lucene/codecs/lucene102/Lucene102HnswBinaryQuantizedVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
128128
beamWidth,
129129
flatVectorsFormat.fieldsWriter(state),
130130
numMergeWorkers,
131-
mergeExec);
131+
mergeExec,
132+
0);
132133
}
133134

134135
@Override

lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104HnswScalarQuantizedVectorsFormat.java

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
2020
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
2121
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
22+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.HNSW_GRAPH_THRESHOLD;
2223
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
2324
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
2425

@@ -60,6 +61,20 @@ public class Lucene104HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat
6061
/** The format for storing, reading, merging vectors on disk */
6162
private final Lucene104ScalarQuantizedVectorsFormat flatVectorsFormat;
6263

64+
/**
65+
* The threshold to use to bypass HNSW graph building for tiny segments in terms of k for a graph
66+
* i.e. number of docs to match the query (default is {@link
67+
* Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD}).
68+
*
69+
* <ul>
70+
* <li>0 indicates that the graph is always built.
71+
* <li>greater than 0 indicates that the graph needs a certain number of nodes before it starts
72+
* building. See {@link Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD} for details.
73+
* <li>Negative values aren't allowed.
74+
* </ul>
75+
*/
76+
private final int tinySegmentsThreshold;
77+
6378
private final int numMergeWorkers;
6479
private final TaskExecutor mergeExec;
6580

@@ -70,7 +85,8 @@ public Lucene104HnswScalarQuantizedVectorsFormat() {
7085
DEFAULT_MAX_CONN,
7186
DEFAULT_BEAM_WIDTH,
7287
DEFAULT_NUM_MERGE_WORKER,
73-
null);
88+
null,
89+
HNSW_GRAPH_THRESHOLD);
7490
}
7591

7692
/**
@@ -80,7 +96,13 @@ public Lucene104HnswScalarQuantizedVectorsFormat() {
8096
* @param beamWidth the size of the queue maintained during graph construction.
8197
*/
8298
public Lucene104HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
83-
this(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
99+
this(
100+
ScalarEncoding.UNSIGNED_BYTE,
101+
maxConn,
102+
beamWidth,
103+
DEFAULT_NUM_MERGE_WORKER,
104+
null,
105+
HNSW_GRAPH_THRESHOLD);
84106
}
85107

86108
/**
@@ -100,6 +122,26 @@ public Lucene104HnswScalarQuantizedVectorsFormat(
100122
int beamWidth,
101123
int numMergeWorkers,
102124
ExecutorService mergeExec) {
125+
this(encoding, maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD);
126+
}
127+
128+
/**
129+
* Constructs a format using the given graph construction parameters and scalar quantization.
130+
*
131+
* @param maxConn the maximum number of connections to a node in the HNSW graph
132+
* @param beamWidth the size of the queue maintained during graph construction.
133+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
134+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
135+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
136+
* generated by this format to do the merge
137+
*/
138+
public Lucene104HnswScalarQuantizedVectorsFormat(
139+
ScalarEncoding encoding,
140+
int maxConn,
141+
int beamWidth,
142+
int numMergeWorkers,
143+
ExecutorService mergeExec,
144+
int tinySegmentsThreshold) {
103145
super(NAME);
104146
flatVectorsFormat = new Lucene104ScalarQuantizedVectorsFormat(encoding);
105147
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
@@ -118,6 +160,7 @@ public Lucene104HnswScalarQuantizedVectorsFormat(
118160
}
119161
this.maxConn = maxConn;
120162
this.beamWidth = beamWidth;
163+
this.tinySegmentsThreshold = tinySegmentsThreshold;
121164
if (numMergeWorkers == 1 && mergeExec != null) {
122165
throw new IllegalArgumentException(
123166
"No executor service is needed as we'll use single thread to merge");
@@ -138,7 +181,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
138181
beamWidth,
139182
flatVectorsFormat.fieldsWriter(state),
140183
numMergeWorkers,
141-
mergeExec);
184+
mergeExec,
185+
tinySegmentsThreshold);
142186
}
143187

144188
@Override
@@ -157,6 +201,8 @@ public String toString() {
157201
+ maxConn
158202
+ ", beamWidth="
159203
+ beamWidth
204+
+ ", tinySegmentsThreshold="
205+
+ tinySegmentsThreshold
160206
+ ", flatVectorFormat="
161207
+ flatVectorsFormat
162208
+ ")";

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
116116
/** Default to use single thread merge */
117117
public static final int DEFAULT_NUM_MERGE_WORKER = 1;
118118

119+
/**
120+
* Minimum estimated search effort (in terms of expected visited nodes) required before building
121+
* an HNSW graph for a segment.
122+
*
123+
* <p>This threshold is compared against the value produced by {@link
124+
* org.apache.lucene.util.hnsw.HnswGraphSearcher#expectedVisitedNodes(int, int)}, which estimates
125+
* how many nodes would be visited during a vector search based on the current graph size and
126+
* {@code k} (neighbours to find).
127+
*
128+
* <p>If the estimated number of visited nodes falls below this threshold, HNSW graph construction
129+
* is skipped for that segment - typically for small flushes or low document count segments -
130+
* since the overhead of building the graph would outweigh its search benefits.
131+
*
132+
* <p>Default: {@code 100}
133+
*/
134+
public static final int HNSW_GRAPH_THRESHOLD = 100;
135+
119136
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
120137

121138
/**
@@ -138,11 +155,30 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
138155
private final int numMergeWorkers;
139156
private final TaskExecutor mergeExec;
140157

158+
/**
159+
* The threshold to use to bypass HNSW graph building for tiny segments in terms of k for a graph
160+
* i.e. number of docs to match the query (default is {@link
161+
* Lucene99HnswVectorsFormat#HNSW_GRAPH_THRESHOLD}).
162+
*
163+
* <ul>
164+
* <li>0 indicates that the graph is always built.
165+
* <li>greater than 0 indicates that the graph needs a certain number of nodes before it starts building.
166+
* <li>Negative values aren't allowed.
167+
* </ul>
168+
*/
169+
private final int tinySegmentsThreshold;
170+
141171
private final int writeVersion;
142172

143173
/** Constructs a format using default graph construction parameters */
144174
public Lucene99HnswVectorsFormat() {
145-
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
175+
this(
176+
DEFAULT_MAX_CONN,
177+
DEFAULT_BEAM_WIDTH,
178+
DEFAULT_NUM_MERGE_WORKER,
179+
null,
180+
HNSW_GRAPH_THRESHOLD,
181+
VERSION_CURRENT);
146182
}
147183

148184
/**
@@ -152,7 +188,20 @@ public Lucene99HnswVectorsFormat() {
152188
* @param beamWidth the size of the queue maintained during graph construction.
153189
*/
154190
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
155-
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
191+
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, HNSW_GRAPH_THRESHOLD, VERSION_CURRENT);
192+
}
193+
194+
/**
195+
* Constructs a format using the given graph construction parameters.
196+
*
197+
* @param maxConn the maximum number of connections to a node in the HNSW graph
198+
* @param beamWidth the size of the queue maintained during graph construction.
199+
* @param tinySegmentsThreshold the expected
199+
*     number of vector operations needed to return the k nearest
200+
*     neighbors given the current graph size
201+
*/
202+
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth, int tinySegmentsThreshold) {
203+
this(
204+
maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, tinySegmentsThreshold, VERSION_CURRENT);
156205
}
157206

158207
/**
@@ -168,7 +217,7 @@ public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
168217
*/
169218
public Lucene99HnswVectorsFormat(
170219
int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
171-
this(maxConn, beamWidth, numMergeWorkers, mergeExec, VERSION_CURRENT);
220+
this(maxConn, beamWidth, numMergeWorkers, mergeExec, HNSW_GRAPH_THRESHOLD, VERSION_CURRENT);
172221
}
173222

174223
/**
@@ -182,13 +231,38 @@ public Lucene99HnswVectorsFormat(
182231
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
183232
* generated by this format to do the merge. If null, the configured {@link
184233
* MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used.
234+
* @param tinySegmentsThreshold the expected number of vector operations to return k nearest
235+
* neighbors of the current graph size
236+
*/
237+
public Lucene99HnswVectorsFormat(
238+
int maxConn,
239+
int beamWidth,
240+
int numMergeWorkers,
241+
ExecutorService mergeExec,
242+
int tinySegmentsThreshold) {
243+
this(maxConn, beamWidth, numMergeWorkers, mergeExec, tinySegmentsThreshold, VERSION_CURRENT);
244+
}
245+
246+
/**
247+
* Constructs a format using the given graph construction parameters.
248+
*
249+
* @param maxConn the maximum number of connections to a node in the HNSW graph
250+
* @param beamWidth the size of the queue maintained during graph construction.
251+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
252+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
253+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
254+
* generated by this format to do the merge. If null, the configured {@link
255+
* MergeScheduler#getIntraMergeExecutor(MergePolicy.OneMerge)} is used.
256+
* @param tinySegmentsThreshold the expected number of vector operations to return k nearest
257+
* neighbors of the current graph size
185258
* @param writeVersion the version used for the writer to encode docID's (VarInt=0, GroupVarInt=1)
186259
*/
187260
Lucene99HnswVectorsFormat(
188261
int maxConn,
189262
int beamWidth,
190263
int numMergeWorkers,
191264
ExecutorService mergeExec,
265+
int tinySegmentsThreshold,
192266
int writeVersion) {
193267
super("Lucene99HnswVectorsFormat");
194268
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
@@ -207,6 +281,7 @@ public Lucene99HnswVectorsFormat(
207281
}
208282
this.maxConn = maxConn;
209283
this.beamWidth = beamWidth;
284+
this.tinySegmentsThreshold = tinySegmentsThreshold;
210285
this.writeVersion = writeVersion;
211286
if (numMergeWorkers == 1 && mergeExec != null) {
212287
throw new IllegalArgumentException(
@@ -229,6 +304,7 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
229304
flatVectorsFormat.fieldsWriter(state),
230305
numMergeWorkers,
231306
mergeExec,
307+
tinySegmentsThreshold,
232308
writeVersion);
233309
}
234310

@@ -248,6 +324,8 @@ public String toString() {
248324
+ maxConn
249325
+ ", beamWidth="
250326
+ beamWidth
327+
+ ", tinySegmentsThreshold="
328+
+ tinySegmentsThreshold
251329
+ ", flatVectorFormat="
252330
+ flatVectorsFormat
253331
+ ")";

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,18 +332,18 @@ private void search(
332332
final RandomVectorScorer scorer = scorerSupplier.get();
333333
final KnnCollector collector =
334334
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
335-
HnswGraph graph = getGraph(fieldEntry);
336335
// Take into account if quantized? E.g. some scorer cost?
337336
// Use approximate cardinality as this is good enough, but ensure we don't exceed the graph
338337
// size as that is illogical
338+
HnswGraph graph = getGraph(fieldEntry);
339339
int filteredDocCount = Math.min(acceptDocs.cost(), graph.size());
340340
Bits accepted = acceptDocs.bits();
341341
final Bits acceptedOrds = scorer.getAcceptOrds(accepted);
342342
int numVectors = scorer.maxOrd();
343343
boolean doHnsw = knnCollector.k() < numVectors;
344344
// The approximate number of vectors that would be visited if we did not filter
345345
int unfilteredVisit = HnswGraphSearcher.expectedVisitedNodes(knnCollector.k(), graph.size());
346-
if (unfilteredVisit >= filteredDocCount) {
346+
if (unfilteredVisit >= filteredDocCount || graph.size() == 0) {
347347
doHnsw = false;
348348
}
349349
if (doHnsw) {
@@ -399,6 +399,9 @@ public HnswGraph getGraph(String field) throws IOException {
399399
}
400400

401401
private HnswGraph getGraph(FieldEntry entry) throws IOException {
402+
if (entry.vectorIndexLength == 0) {
403+
return HnswGraph.EMPTY;
404+
}
402405
return new OffHeapHnswGraph(entry, vectorIndex);
403406
}
404407

0 commit comments

Comments
 (0)