Skip to content

Commit 3af34ab

Browse files
Iter1
- Created GPUVectorsFormat, which writes and reads flat vectors - Added a new index_options type, gpu, for the dense_vector field; it is available only under the feature flag
1 parent eeace58 commit 3af34ab

File tree

8 files changed

+660
-1
lines changed

8 files changed

+660
-1
lines changed

server/src/main/java/module-info.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,8 @@
457457
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
458458
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
459459
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
460-
org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
460+
org.elasticsearch.index.codec.vectors.IVFVectorsFormat,
461+
org.elasticsearch.index.codec.vectors.GPUVectorsFormat;
461462

462463
provides org.apache.lucene.codecs.Codec
463464
with
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.KnnVectorsFormat;
13+
import org.apache.lucene.codecs.KnnVectorsReader;
14+
import org.apache.lucene.codecs.KnnVectorsWriter;
15+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
16+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
17+
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
18+
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
19+
import org.apache.lucene.index.SegmentReadState;
20+
import org.apache.lucene.index.SegmentWriteState;
21+
22+
import java.io.IOException;
23+
24+
/**
25+
* Codec format for GPU-accelerated vector indexes. This format is designed to
26+
* leverage GPU processing capabilities for vector search operations.
27+
*/
28+
public class GPUVectorsFormat extends KnnVectorsFormat {
29+
30+
public static final String NAME = "GPUVectorsFormat";
31+
public static final String GPU_IDX_EXTENSION = "gpuidx";
32+
public static final String GPU_META_EXTENSION = "mgpu";
33+
34+
public static final int VERSION_START = 0;
35+
public static final int VERSION_CURRENT = VERSION_START;
36+
37+
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
38+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
39+
);
40+
41+
public GPUVectorsFormat() {
42+
super(NAME);
43+
}
44+
45+
@Override
46+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
47+
return new GPUVectorsWriter(state, rawVectorFormat.fieldsWriter(state));
48+
}
49+
50+
@Override
51+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
52+
return new GPUVectorsReader(state, rawVectorFormat.fieldsReader(state));
53+
}
54+
55+
@Override
56+
public int getMaxDimensions(String fieldName) {
57+
return 4096;
58+
}
59+
60+
@Override
61+
public String toString() {
62+
return NAME + "()";
63+
}
64+
65+
static GPUVectorsReader getGPUReader(KnnVectorsReader vectorsReader, String fieldName) {
66+
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
67+
vectorsReader = candidateReader.getFieldReader(fieldName);
68+
}
69+
if (vectorsReader instanceof GPUVectorsReader reader) {
70+
return reader;
71+
}
72+
return null;
73+
}
74+
}
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.CodecUtil;
13+
import org.apache.lucene.codecs.KnnVectorsReader;
14+
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
15+
import org.apache.lucene.index.ByteVectorValues;
16+
import org.apache.lucene.index.CorruptIndexException;
17+
import org.apache.lucene.index.FieldInfo;
18+
import org.apache.lucene.index.FieldInfos;
19+
import org.apache.lucene.index.FloatVectorValues;
20+
import org.apache.lucene.index.IndexFileNames;
21+
import org.apache.lucene.index.SegmentReadState;
22+
import org.apache.lucene.index.VectorEncoding;
23+
import org.apache.lucene.index.VectorSimilarityFunction;
24+
import org.apache.lucene.internal.hppc.IntObjectHashMap;
25+
import org.apache.lucene.search.KnnCollector;
26+
import org.apache.lucene.store.ChecksumIndexInput;
27+
import org.apache.lucene.store.DataInput;
28+
import org.apache.lucene.store.IOContext;
29+
import org.apache.lucene.store.IndexInput;
30+
import org.apache.lucene.util.Bits;
31+
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
32+
import org.apache.lucene.util.hnsw.RandomVectorScorer;
33+
import org.elasticsearch.core.IOUtils;
34+
35+
import java.io.IOException;
36+
37+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
38+
39+
/**
40+
* Reader for GPU-accelerated vectors. This reader is used to read the GPU vectors from the index.
41+
*/
42+
public class GPUVectorsReader extends KnnVectorsReader {
43+
44+
private final IndexInput gpuIdx;
45+
private final SegmentReadState state;
46+
private final FieldInfos fieldInfos;
47+
protected final IntObjectHashMap<FieldEntry> fields;
48+
private final FlatVectorsReader rawVectorsReader;
49+
50+
@SuppressWarnings("this-escape")
51+
public GPUVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
52+
this.state = state;
53+
this.fieldInfos = state.fieldInfos;
54+
this.rawVectorsReader = rawVectorsReader;
55+
this.fields = new IntObjectHashMap<>();
56+
String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, GPUVectorsFormat.GPU_META_EXTENSION);
57+
58+
int versionMeta = -1;
59+
boolean success = false;
60+
try (ChecksumIndexInput gpuMeta = state.directory.openChecksumInput(meta)) {
61+
Throwable priorE = null;
62+
try {
63+
versionMeta = CodecUtil.checkIndexHeader(
64+
gpuMeta,
65+
GPUVectorsFormat.NAME,
66+
GPUVectorsFormat.VERSION_START,
67+
GPUVectorsFormat.VERSION_CURRENT,
68+
state.segmentInfo.getId(),
69+
state.segmentSuffix
70+
);
71+
readFields(gpuMeta);
72+
} catch (Throwable exception) {
73+
priorE = exception;
74+
} finally {
75+
CodecUtil.checkFooter(gpuMeta, priorE);
76+
}
77+
gpuIdx = openDataInput(state, versionMeta, GPUVectorsFormat.GPU_IDX_EXTENSION, GPUVectorsFormat.NAME, state.context);
78+
success = true;
79+
} finally {
80+
if (success == false) {
81+
IOUtils.closeWhileHandlingException(this);
82+
}
83+
}
84+
}
85+
86+
private static IndexInput openDataInput(
87+
SegmentReadState state,
88+
int versionMeta,
89+
String fileExtension,
90+
String codecName,
91+
IOContext context
92+
) throws IOException {
93+
final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
94+
final IndexInput in = state.directory.openInput(fileName, context);
95+
boolean success = false;
96+
try {
97+
final int versionVectorData = CodecUtil.checkIndexHeader(
98+
in,
99+
codecName,
100+
GPUVectorsFormat.VERSION_START,
101+
GPUVectorsFormat.VERSION_CURRENT,
102+
state.segmentInfo.getId(),
103+
state.segmentSuffix
104+
);
105+
if (versionMeta != versionVectorData) {
106+
throw new CorruptIndexException(
107+
"Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
108+
in
109+
);
110+
}
111+
CodecUtil.retrieveChecksum(in);
112+
success = true;
113+
return in;
114+
} finally {
115+
if (success == false) {
116+
IOUtils.closeWhileHandlingException(in);
117+
}
118+
}
119+
}
120+
121+
private void readFields(ChecksumIndexInput meta) throws IOException {
122+
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
123+
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
124+
if (info == null) {
125+
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
126+
}
127+
fields.put(info.number, readField(meta, info));
128+
}
129+
}
130+
131+
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
132+
final VectorEncoding vectorEncoding = readVectorEncoding(input);
133+
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
134+
final long dataOffset = input.readLong();
135+
final long dataLength = input.readLong();
136+
137+
if (similarityFunction != info.getVectorSimilarityFunction()) {
138+
throw new IllegalStateException(
139+
"Inconsistent vector similarity function for field=\""
140+
+ info.name
141+
+ "\"; "
142+
+ similarityFunction
143+
+ " != "
144+
+ info.getVectorSimilarityFunction()
145+
);
146+
}
147+
return new FieldEntry(similarityFunction, vectorEncoding, dataOffset, dataLength);
148+
}
149+
150+
private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
151+
final int i = input.readInt();
152+
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
153+
throw new IllegalArgumentException("invalid distance function: " + i);
154+
}
155+
return SIMILARITY_FUNCTIONS.get(i);
156+
}
157+
158+
private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
159+
final int encodingId = input.readInt();
160+
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
161+
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
162+
}
163+
return VectorEncoding.values()[encodingId];
164+
}
165+
166+
@Override
167+
public final void checkIntegrity() throws IOException {
168+
rawVectorsReader.checkIntegrity();
169+
CodecUtil.checksumEntireFile(gpuIdx);
170+
}
171+
172+
@Override
173+
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
174+
return rawVectorsReader.getFloatVectorValues(field);
175+
}
176+
177+
@Override
178+
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
179+
return rawVectorsReader.getByteVectorValues(field);
180+
}
181+
182+
@Override
183+
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
184+
// TODO: Implement GPU-accelerated search
185+
collectAllMatchingDocs(knnCollector, acceptDocs, rawVectorsReader.getRandomVectorScorer(field, target));
186+
}
187+
188+
private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs, RandomVectorScorer scorer) throws IOException {
189+
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
190+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
191+
for (int i = 0; i < scorer.maxOrd(); i++) {
192+
if (acceptedOrds == null || acceptedOrds.get(i)) {
193+
collector.collect(i, scorer.score(i));
194+
collector.incVisitedCount(1);
195+
}
196+
}
197+
assert collector.earlyTerminated() == false;
198+
}
199+
200+
@Override
201+
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
202+
collectAllMatchingDocs(knnCollector, acceptDocs, rawVectorsReader.getRandomVectorScorer(field, target));
203+
}
204+
205+
@Override
206+
public void close() throws IOException {
207+
IOUtils.close(rawVectorsReader, gpuIdx);
208+
}
209+
210+
protected record FieldEntry(
211+
VectorSimilarityFunction similarityFunction,
212+
VectorEncoding vectorEncoding,
213+
long dataOffset,
214+
long dataLength
215+
) {
216+
IndexInput dataSlice(IndexInput dataFile) throws IOException {
217+
return dataFile.slice("gpu-data", dataOffset, dataLength);
218+
}
219+
}
220+
}

0 commit comments

Comments
 (0)