Skip to content

Commit cff3e94

Browse files
Send quantized data to GPU for index building during merge
1 parent 1a727fa commit cff3e94

File tree

7 files changed

+415
-12
lines changed

7 files changed

+415
-12
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException
132132
);
133133
}
134134

135-
static final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
135+
public static final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
136136

137137
final Lucene99ScalarQuantizedVectorsWriter delegate;
138138

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@ static DatasetUtils getInstance() {
2020
}
2121

2222
/** Returns a Dataset over the float32 vectors in the input. */
23-
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims) throws IOException;
23+
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;
2424
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ static DatasetUtils getInstance() {
1919
}
2020

2121
@Override
22-
public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims) {
22+
public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) {
2323
throw new UnsupportedOperationException("should not reach here");
2424
}
2525
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.lucene.codecs.KnnVectorsWriter;
1717
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
1818
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
19+
import org.apache.lucene.index.ByteVectorValues;
1920
import org.apache.lucene.index.DocsWithFieldSet;
2021
import org.apache.lucene.index.FieldInfo;
2122
import org.apache.lucene.index.FloatVectorValues;
@@ -35,8 +36,10 @@
3536
import org.apache.lucene.util.hnsw.HnswGraph;
3637
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
3738
import org.apache.lucene.util.packed.DirectMonotonicWriter;
39+
import org.apache.lucene.util.quantization.ScalarQuantizer;
3840
import org.elasticsearch.core.IOUtils;
3941
import org.elasticsearch.core.SuppressForbidden;
42+
import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
4043
import org.elasticsearch.logging.LogManager;
4144
import org.elasticsearch.logging.Logger;
4245

@@ -49,6 +52,7 @@
4952
import java.util.Objects;
5053

5154
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
55+
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter.mergeAndRecalculateQuantiles;
5256
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
5357
import static org.elasticsearch.xpack.gpu.codec.ESGpuHnswVectorsFormat.LUCENE99_HNSW_META_CODEC_NAME;
5458
import static org.elasticsearch.xpack.gpu.codec.ESGpuHnswVectorsFormat.LUCENE99_HNSW_META_EXTENSION;
@@ -75,6 +79,7 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
7579

7680
private final List<FieldWriter> fields = new ArrayList<>();
7781
private boolean finished;
82+
private final CuVSMatrix.DataType dataType;
7883

7984
ESGpuHnswVectorsWriter(
8085
CuVSResourceManager cuVSResourceManager,
@@ -88,6 +93,11 @@ final class ESGpuHnswVectorsWriter extends KnnVectorsWriter {
8893
this.M = M;
8994
this.beamWidth = beamWidth;
9095
this.flatVectorWriter = flatVectorWriter;
96+
if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
97+
dataType = CuVSMatrix.DataType.BYTE;
98+
} else {
99+
dataType = CuVSMatrix.DataType.FLOAT;
100+
}
91101
this.segmentWriteState = state;
92102
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, LUCENE99_HNSW_META_EXTENSION);
93103
String indexDataFileName = IndexFileNames.segmentFileName(
@@ -411,13 +421,17 @@ public NodesIterator getNodesOnLevel(int level) {
411421
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
412422
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
413423
flatVectorWriter.mergeOneField(fieldInfo, mergeState);
414-
// save merged vector values to a temp file
415424
final int numVectors;
416425
String tempRawVectorsFileName = null;
417426
boolean success = false;
427+
// save merged vector values to a temp file
418428
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "vec_", IOContext.DEFAULT)) {
419429
tempRawVectorsFileName = out.getName();
420-
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
430+
if (dataType == CuVSMatrix.DataType.BYTE) {
431+
numVectors = writeByteVectorValues(out, getMergedByteVectorValues(fieldInfo, mergeState));
432+
} else {
433+
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
434+
}
421435
CodecUtil.writeFooter(out);
422436
success = true;
423437
} finally {
@@ -429,9 +443,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
429443
DatasetOrVectors datasetOrVectors;
430444
var input = FilterIndexInput.unwrapOnlyTest(in);
431445
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
432-
var ds = DatasetUtils.getInstance().fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension());
446+
var ds = DatasetUtils.getInstance()
447+
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
433448
datasetOrVectors = DatasetOrVectors.fromDataset(ds);
434449
} else {
450+
// TODO fix for byte vectors
435451
var fa = copyVectorsIntoArray(in, fieldInfo, numVectors);
436452
datasetOrVectors = DatasetOrVectors.fromArray(fa);
437453
}
@@ -441,6 +457,31 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
441457
}
442458
}
443459

460+
private ByteVectorValues getMergedByteVectorValues(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
461+
// TODO: expose confidence interval from the format
462+
final byte bits = 7;
463+
final Float confidenceInterval = null;
464+
ScalarQuantizer quantizer = mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
465+
MergedQuantizedVectorValues byteVectorValues = MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
466+
fieldInfo,
467+
mergeState,
468+
quantizer
469+
);
470+
return byteVectorValues;
471+
}
472+
473+
private static int writeByteVectorValues(IndexOutput out, ByteVectorValues vectorValues) throws IOException {
474+
int numVectors = 0;
475+
byte[] vector;
476+
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
477+
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
478+
numVectors++;
479+
vector = vectorValues.vectorValue(iterator.index());
480+
out.writeBytes(vector, vector.length);
481+
}
482+
return numVectors;
483+
}
484+
444485
static float[][] copyVectorsIntoArray(IndexInput in, FieldInfo fieldInfo, int numVectors) throws IOException {
445486
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
446487
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];

0 commit comments

Comments
 (0)