Skip to content

Commit 9c1db2d

Browse files
Expose m and efConstruction params in GPU index building
1 parent dddadce commit 9c1db2d

File tree

5 files changed

+49
-39
lines changed

5 files changed

+49
-39
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2117,7 +2117,7 @@ public boolean updatableTo(DenseVectorIndexOptions update) {
21172117
}
21182118
}
21192119

2120-
static class HnswIndexOptions extends DenseVectorIndexOptions {
2120+
public static class HnswIndexOptions extends DenseVectorIndexOptions {
21212121
private final int m;
21222122
private final int efConstruction;
21232123

@@ -2160,6 +2160,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
21602160
return builder;
21612161
}
21622162

2163+
public int m() {
2164+
return m;
2165+
}
2166+
2167+
public int efConstruction() {
2168+
return efConstruction;
2169+
}
2170+
21632171
@Override
21642172
public boolean doEquals(DenseVectorIndexOptions o) {
21652173
if (this == o) return true;

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*/
77
package org.elasticsearch.xpack.gpu;
88

9+
import org.apache.lucene.codecs.KnnVectorsFormat;
10+
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
911
import org.elasticsearch.common.util.FeatureFlag;
1012
import org.elasticsearch.index.IndexSettings;
1113
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -37,12 +39,12 @@ public VectorsFormatProvider getVectorsFormatProvider() {
3739
"[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node."
3840
);
3941
}
40-
return new GPUVectorsFormat();
42+
return getVectorsFormat(indexOptions);
4143
}
4244
if (gpuMode == IndexSettings.GpuMode.AUTO
4345
&& vectorIndexTypeSupported(indexOptions.getType())
4446
&& GPUSupport.isSupported(false)) {
45-
return new GPUVectorsFormat();
47+
return getVectorsFormat(indexOptions);
4648
}
4749
}
4850
return null;
@@ -52,4 +54,19 @@ && vectorIndexTypeSupported(indexOptions.getType())
5254
private boolean vectorIndexTypeSupported(DenseVectorFieldMapper.VectorIndexType type) {
5355
return type == DenseVectorFieldMapper.VectorIndexType.HNSW;
5456
}
57+
58+
private static KnnVectorsFormat getVectorsFormat(DenseVectorFieldMapper.DenseVectorIndexOptions indexOptions) {
59+
if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.HNSW) {
60+
DenseVectorFieldMapper.HnswIndexOptions hnswIndexOptions = (DenseVectorFieldMapper.HnswIndexOptions) indexOptions;
61+
int efConstruction = hnswIndexOptions.efConstruction();
62+
if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
63+
efConstruction = GPUVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
64+
}
65+
return new GPUVectorsFormat(hnswIndexOptions.m(), efConstruction);
66+
} else {
67+
throw new IllegalArgumentException(
68+
"GPU vector indexing is not supported on this vector type: [" + indexOptions.getType() + "]"
69+
);
70+
}
71+
}
5572
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ final class GPUToHNSWVectorsWriter extends KnnVectorsWriter {
8686
assert cuVSResourceManager != null : "CuVSResources must not be null";
8787
this.cuVSResourceManager = cuVSResourceManager;
8888
this.M = M;
89-
this.flatVectorWriter = flatVectorWriter;
9089
this.beamWidth = beamWidth;
90+
this.flatVectorWriter = flatVectorWriter;
9191
this.segmentWriteState = state;
9292
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, LUCENE99_HNSW_META_EXTENSION);
9393
String indexDataFileName = IndexFileNames.segmentFileName(
@@ -274,10 +274,11 @@ private CagraIndex buildGPUIndex(
274274
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
275275
};
276276

277-
// TODO: expose cagra index params of intermediate graph degree, graph degree, algorithm, NNDescentNumIterations
277+
// TODO: expose cagra index params for algorithm, NNDescentNumIterations
278278
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
279279
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
280-
.withGraphDegree(16)
280+
.withGraphDegree(M)
281+
.withIntermediateGraphDegree(beamWidth)
281282
.withMetric(distanceType)
282283
.build();
283284

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
1717
import org.apache.lucene.index.SegmentReadState;
1818
import org.apache.lucene.index.SegmentWriteState;
19-
import org.elasticsearch.logging.LogManager;
20-
import org.elasticsearch.logging.Logger;
2119

2220
import java.io.IOException;
2321

@@ -26,9 +24,6 @@
2624
* leverage GPU processing capabilities for vector search operations.
2725
*/
2826
public class GPUVectorsFormat extends KnnVectorsFormat {
29-
30-
private static final Logger LOG = LogManager.getLogger(GPUVectorsFormat.class);
31-
3227
public static final String NAME = "GPUVectorsFormat";
3328
public static final int VERSION_START = 0;
3429

@@ -38,34 +33,38 @@ public class GPUVectorsFormat extends KnnVectorsFormat {
3833
static final String LUCENE99_HNSW_VECTOR_INDEX_EXTENSION = "vex";
3934
static final int LUCENE99_VERSION_CURRENT = VERSION_START;
4035

41-
static final int DEFAULT_MAX_CONN = 16;
42-
static final int DEFAULT_BEAM_WIDTH = 100;
36+
static final int DEFAULT_MAX_CONN = 16; // graph degree
37+
public static final int DEFAULT_BEAM_WIDTH = 128; // intermediate graph degree
4338
static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;
4439

4540
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(
4641
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
4742
);
4843

44+
// How many nodes each node in the graph is connected to in the final graph
45+
private final int maxConn;
46+
// Intermediate graph degree, the number of connections for each node before pruning
47+
private final int beamWidth;
4948
final CuVSResourceManager cuVSResourceManager;
5049

5150
public GPUVectorsFormat() {
52-
this(CuVSResourceManager.pooling());
51+
this(CuVSResourceManager.pooling(), DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
5352
}
5453

55-
public GPUVectorsFormat(CuVSResourceManager cuVSResourceManager) {
54+
public GPUVectorsFormat(int maxConn, int beamWidth) {
55+
this(CuVSResourceManager.pooling(), maxConn, beamWidth);
56+
};
57+
58+
public GPUVectorsFormat(CuVSResourceManager cuVSResourceManager, int maxConn, int beamWidth) {
5659
super(NAME);
5760
this.cuVSResourceManager = cuVSResourceManager;
61+
this.maxConn = maxConn;
62+
this.beamWidth = beamWidth;
5863
}
5964

6065
@Override
6166
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
62-
return new GPUToHNSWVectorsWriter(
63-
cuVSResourceManager,
64-
state,
65-
DEFAULT_MAX_CONN,
66-
DEFAULT_BEAM_WIDTH,
67-
flatVectorsFormat.fieldsWriter(state)
68-
);
67+
return new GPUToHNSWVectorsWriter(cuVSResourceManager, state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state));
6968
}
7069

7170
@Override
@@ -80,6 +79,6 @@ public int getMaxDimensions(String fieldName) {
8079

8180
@Override
8281
public String toString() {
83-
return NAME + "()";
82+
return NAME + "(maxConn=" + maxConn + ", beamWidth=" + beamWidth + ", flatVectorFormat=" + flatVectorsFormat.getName() + ")";
8483
}
8584
}

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,14 @@ public void testMismatchedFields() throws Exception {
7676
// No bytes support
7777
}
7878

79-
@Override
80-
public void testSortedIndex() throws Exception {
81-
// TODO: implement sorted index
82-
}
83-
84-
@Override
85-
public void testFloatVectorScorerIteration() throws Exception {
86-
// TODO: implement sorted index
87-
}
88-
89-
@Override
90-
public void testRandom() throws Exception {
91-
// TODO: implement sorted index
92-
}
93-
9479
public void testToString() {
9580
FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) {
9681
@Override
9782
public KnnVectorsFormat knnVectorsFormat() {
9883
return new GPUVectorsFormat();
9984
}
10085
};
101-
String expectedPattern = "GPUVectorsFormat()";
86+
String expectedPattern = "GPUVectorsFormat(maxConn=16, beamWidth=128, flatVectorFormat=Lucene99FlatVectorsFormat)";
10287
assertEquals(expectedPattern, customCodec.knnVectorsFormat().toString());
10388
}
10489

0 commit comments

Comments
 (0)