Skip to content

Commit 5912eab

Browse files
authored
Pull out common subclasses for our custom flat & hnsw vector formats (#132663)
1 parent a4e3b0a commit 5912eab

11 files changed

+226
-256
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
13+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
14+
import org.elasticsearch.core.SuppressForbidden;
15+
16+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
17+
18+
public abstract class AbstractFlatVectorsFormat extends FlatVectorsFormat {
19+
20+
public static final boolean USE_DIRECT_IO = getUseDirectIO();
21+
22+
@SuppressForbidden(
23+
reason = "TODO Deprecate any lenient usage of Boolean#parseBoolean https://github.com/elastic/elasticsearch/issues/128993"
24+
)
25+
private static boolean getUseDirectIO() {
26+
return Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"));
27+
}
28+
29+
protected AbstractFlatVectorsFormat(String name) {
30+
super(name);
31+
}
32+
33+
protected abstract FlatVectorsScorer flatVectorsScorer();
34+
35+
@Override
36+
public int getMaxDimensions(String fieldName) {
37+
return MAX_DIMS_COUNT;
38+
}
39+
40+
@Override
41+
public String toString() {
42+
return getName() + "(name=" + getName() + ", flatVectorScorer=" + flatVectorsScorer() + ")";
43+
}
44+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.KnnVectorsFormat;
13+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
14+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
15+
import org.apache.lucene.search.TaskExecutor;
16+
import org.apache.lucene.util.hnsw.HnswGraph;
17+
18+
import java.util.concurrent.ExecutorService;
19+
20+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
21+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
22+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
23+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
24+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
25+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
26+
27+
public abstract class AbstractHnswVectorsFormat extends KnnVectorsFormat {
28+
29+
/**
30+
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
31+
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
32+
*/
33+
protected final int maxConn;
34+
35+
/**
36+
* The number of candidate neighbors to track while searching the graph for each newly inserted
37+
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
38+
* for details.
39+
*/
40+
protected final int beamWidth;
41+
42+
protected final int numMergeWorkers;
43+
protected final TaskExecutor mergeExec;
44+
45+
/** Constructs a format using default graph construction parameters */
46+
protected AbstractHnswVectorsFormat(String name) {
47+
this(name, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
48+
}
49+
50+
/**
51+
* Constructs a format using the given graph construction parameters.
52+
*
53+
* @param maxConn the maximum number of connections to a node in the HNSW graph
54+
* @param beamWidth the size of the queue maintained during graph construction.
55+
*/
56+
protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth) {
57+
this(name, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
58+
}
59+
60+
/**
61+
* Constructs a format using the given graph construction parameters and scalar quantization.
62+
*
63+
* @param maxConn the maximum number of connections to a node in the HNSW graph
64+
* @param beamWidth the size of the queue maintained during graph construction.
65+
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
66+
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
67+
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
68+
* generated by this format to do the merge
69+
*/
70+
protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
71+
super(name);
72+
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
73+
throw new IllegalArgumentException(
74+
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
75+
);
76+
}
77+
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
78+
throw new IllegalArgumentException(
79+
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
80+
);
81+
}
82+
this.maxConn = maxConn;
83+
this.beamWidth = beamWidth;
84+
if (numMergeWorkers == 1 && mergeExec != null) {
85+
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
86+
}
87+
this.numMergeWorkers = numMergeWorkers;
88+
if (mergeExec != null) {
89+
this.mergeExec = new TaskExecutor(mergeExec);
90+
} else {
91+
this.mergeExec = null;
92+
}
93+
}
94+
95+
protected abstract FlatVectorsFormat flatVectorsFormat();
96+
97+
@Override
98+
public int getMaxDimensions(String fieldName) {
99+
return MAX_DIMS_COUNT;
100+
}
101+
102+
@Override
103+
public String toString() {
104+
return getName()
105+
+ "(name="
106+
+ getName()
107+
+ ", maxConn="
108+
+ maxConn
109+
+ ", beamWidth="
110+
+ beamWidth
111+
+ ", flatVectorFormat="
112+
+ flatVectorsFormat()
113+
+ ")";
114+
}
115+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
import org.apache.lucene.codecs.KnnVectorsFormat;
1312
import org.apache.lucene.codecs.KnnVectorsReader;
1413
import org.apache.lucene.codecs.KnnVectorsWriter;
1514
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
@@ -22,19 +21,11 @@
2221

2322
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
2423
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
25-
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
2624

27-
public final class ES814HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
25+
public final class ES814HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat {
2826

2927
static final String NAME = "ES814HnswScalarQuantizedVectorsFormat";
3028

31-
static final int MAXIMUM_MAX_CONN = 512;
32-
static final int MAXIMUM_BEAM_WIDTH = 3200;
33-
34-
private final int maxConn;
35-
36-
private final int beamWidth;
37-
3829
/** The format for storing, reading, merging vectors on disk */
3930
private final FlatVectorsFormat flatVectorsFormat;
4031

@@ -43,45 +34,22 @@ public ES814HnswScalarQuantizedVectorsFormat() {
4334
}
4435

4536
public ES814HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth, Float confidenceInterval, int bits, boolean compress) {
46-
super(NAME);
47-
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
48-
throw new IllegalArgumentException(
49-
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
50-
);
51-
}
52-
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
53-
throw new IllegalArgumentException(
54-
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
55-
);
56-
}
57-
this.maxConn = maxConn;
58-
this.beamWidth = beamWidth;
37+
super(NAME, maxConn, beamWidth);
5938
this.flatVectorsFormat = new ES814ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
6039
}
6140

6241
@Override
63-
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
64-
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
65-
}
66-
67-
@Override
68-
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
69-
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
42+
protected FlatVectorsFormat flatVectorsFormat() {
43+
return flatVectorsFormat;
7044
}
7145

7246
@Override
73-
public int getMaxDimensions(String fieldName) {
74-
return MAX_DIMS_COUNT;
47+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
48+
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
7549
}
7650

7751
@Override
78-
public String toString() {
79-
return "ES814HnswScalarQuantizedVectorsFormat(name=ES814HnswScalarQuantizedVectorsFormat, maxConn="
80-
+ maxConn
81-
+ ", beamWidth="
82-
+ beamWidth
83-
+ ", flatVectorFormat="
84-
+ flatVectorsFormat
85-
+ ")";
52+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
53+
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
8654
}
8755
}

server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
import org.apache.lucene.codecs.KnnVectorsFormat;
1312
import org.apache.lucene.codecs.KnnVectorsReader;
1413
import org.apache.lucene.codecs.KnnVectorsWriter;
1514
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
@@ -20,63 +19,32 @@
2019

2120
import java.io.IOException;
2221

23-
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
24-
25-
public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {
22+
public class ES815HnswBitVectorsFormat extends AbstractHnswVectorsFormat {
2623

2724
static final String NAME = "ES815HnswBitVectorsFormat";
2825

29-
static final int MAXIMUM_MAX_CONN = 512;
30-
static final int MAXIMUM_BEAM_WIDTH = 3200;
31-
32-
private final int maxConn;
33-
private final int beamWidth;
34-
3526
private static final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();
3627

3728
public ES815HnswBitVectorsFormat() {
38-
this(16, 100);
39-
}
40-
41-
public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
4229
super(NAME);
43-
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
44-
throw new IllegalArgumentException(
45-
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
46-
);
47-
}
48-
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
49-
throw new IllegalArgumentException(
50-
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
51-
);
52-
}
53-
this.maxConn = maxConn;
54-
this.beamWidth = beamWidth;
5530
}
5631

57-
@Override
58-
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
59-
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
32+
public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
33+
super(NAME, maxConn, beamWidth);
6034
}
6135

6236
@Override
63-
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
64-
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
37+
protected FlatVectorsFormat flatVectorsFormat() {
38+
return flatVectorsFormat;
6539
}
6640

6741
@Override
68-
public String toString() {
69-
return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
70-
+ maxConn
71-
+ ", beamWidth="
72-
+ beamWidth
73-
+ ", flatVectorFormat="
74-
+ flatVectorsFormat
75-
+ ")";
42+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
43+
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
7644
}
7745

7846
@Override
79-
public int getMaxDimensions(String fieldName) {
80-
return MAX_DIMS_COUNT;
47+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
48+
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
8149
}
8250
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java renamed to server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
* License v3.0 only", or the "Server Side Public License, v 1".
88
*/
99

10-
package org.elasticsearch.index.codec.vectors.es818;
10+
package org.elasticsearch.index.codec.vectors;
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
1313
import org.apache.lucene.index.ByteVectorValues;
@@ -25,19 +25,19 @@
2525
import java.util.Collection;
2626
import java.util.Map;
2727

28-
class MergeReaderWrapper extends FlatVectorsReader implements OffHeapStats {
28+
public class MergeReaderWrapper extends FlatVectorsReader implements OffHeapStats {
2929

3030
private final FlatVectorsReader mainReader;
3131
private final FlatVectorsReader mergeReader;
3232

33-
protected MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
33+
public MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
3434
super(mainReader.getFlatVectorScorer());
3535
this.mainReader = mainReader;
3636
this.mergeReader = mergeReader;
3737
}
3838

3939
// For testing
40-
FlatVectorsReader getMainReader() {
40+
public FlatVectorsReader getMainReader() {
4141
return mainReader;
4242
}
4343

0 commit comments

Comments
 (0)