Skip to content

Commit d27b62f

Browse files
committed
Use off-heap Dataset when merging vector data
1 parent 0ee27d9 commit d27b62f

File tree

6 files changed

+230
-12
lines changed

6 files changed

+230
-12
lines changed

x-pack/plugin/gpu/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'elasticsearch.internal-es-plugin'
22
apply plugin: 'elasticsearch.internal-cluster-test'
33
apply plugin: 'elasticsearch.internal-yaml-rest-test'
4+
apply plugin: 'elasticsearch.mrjar'
45

56
esplugin {
67
name = 'gpu'
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import com.nvidia.cuvs.Dataset;
11+
12+
import org.apache.lucene.store.MemorySegmentAccessInput;
13+
14+
import java.io.IOException;
15+
16+
public interface DatasetUtils {
17+
18+
static DatasetUtils getInstance() {
19+
return DatasetUtilsImpl.getInstance();
20+
}
21+
22+
/** Returns a Dataset over the float32 vectors in the input. */
23+
Dataset fromInput(MemorySegmentAccessInput input, int numVectors, int dims) throws IOException;
24+
25+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import com.nvidia.cuvs.Dataset;
11+
12+
import org.apache.lucene.store.MemorySegmentAccessInput;
13+
14+
import java.io.IOException;
15+
16+
/** Stubb holder - never executed. */
17+
public class DatasetUtilsImpl implements DatasetUtils {
18+
19+
static DatasetUtils getInstance() {
20+
throw new UnsupportedOperationException("should not reach here");
21+
}
22+
23+
@Override
24+
public Dataset fromInput(MemorySegmentAccessInput input, int numVectors, int dims) throws IOException {
25+
throw new UnsupportedOperationException("should not reach here");
26+
}
27+
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import org.apache.lucene.index.Sorter;
2828
import org.apache.lucene.index.VectorEncoding;
2929
import org.apache.lucene.index.VectorSimilarityFunction;
30+
import org.apache.lucene.store.FilterIndexInput;
3031
import org.apache.lucene.store.IOContext;
3132
import org.apache.lucene.store.IndexInput;
3233
import org.apache.lucene.store.IndexOutput;
34+
import org.apache.lucene.store.MemorySegmentAccessInput;
3335
import org.apache.lucene.util.RamUsageEstimator;
3436
import org.apache.lucene.util.hnsw.HnswGraph;
3537
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
@@ -175,12 +177,15 @@ private static final class DatasetOrVectors {
175177
private final Dataset dataset;
176178
private final float[][] vectors;
177179

178-
DatasetOrVectors(float[][] vectors) {
179-
this(
180+
static DatasetOrVectors fromArray(float[][] vectors) {
181+
return new DatasetOrVectors(
180182
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? null : Dataset.ofArray(vectors),
181183
vectors.length < MIN_NUM_VECTORS_FOR_GPU_BUILD ? vectors : null
182184
);
183-
validateState();
185+
}
186+
187+
static DatasetOrVectors fromDataset(Dataset dataset) {
188+
return new DatasetOrVectors(dataset, null);
184189
}
185190

186191
private DatasetOrVectors(Dataset dataset, float[][] vectors) {
@@ -210,7 +215,7 @@ float[][] getVectors() {
210215

211216
private void writeField(FieldWriter fieldWriter) throws IOException {
212217
float[][] vectors = fieldWriter.flatFieldVectorsWriter.getVectors().toArray(float[][]::new);
213-
writeFieldInternal(fieldWriter.fieldInfo, new DatasetOrVectors(vectors));
218+
writeFieldInternal(fieldWriter.fieldInfo, DatasetOrVectors.fromArray(vectors));
214219
}
215220

216221
private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException {
@@ -481,21 +486,33 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
481486
}
482487
}
483488
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
484-
// TODO: Improve this (not acceptable): pass tempRawVectorsFileName for the gpuIndex construction through MemorySegment
485-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
486-
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
487-
float[] vector;
488-
for (int i = 0; i < numVectors; i++) {
489-
vector = floatVectorValues.vectorValue(i);
490-
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
489+
DatasetOrVectors datasetOrVectors;
490+
491+
var input = FilterIndexInput.unwrapOnlyTest(in);
492+
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
493+
var ds = DatasetUtils.getInstance().fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension());
494+
datasetOrVectors = DatasetOrVectors.fromDataset(ds);
495+
} else {
496+
var fa = copyVectorsIntoArray(in, fieldInfo, numVectors);
497+
datasetOrVectors = DatasetOrVectors.fromArray(fa);
491498
}
492-
DatasetOrVectors datasetOrVectors = new DatasetOrVectors(vectors);
493499
writeFieldInternal(fieldInfo, datasetOrVectors);
494500
} finally {
495501
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
496502
}
497503
}
498504

505+
static float[][] copyVectorsIntoArray(IndexInput in, FieldInfo fieldInfo, int numVectors) throws IOException {
506+
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
507+
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
508+
float[] vector;
509+
for (int i = 0; i < numVectors; i++) {
510+
vector = floatVectorValues.vectorValue(i);
511+
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
512+
}
513+
return vectors;
514+
}
515+
499516
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
500517
throws IOException {
501518
int numVectors = 0;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import com.nvidia.cuvs.Dataset;
11+
import com.nvidia.cuvs.spi.CuVSProvider;
12+
13+
import org.apache.lucene.store.MemorySegmentAccessInput;
14+
15+
import java.io.IOException;
16+
import java.lang.foreign.MemorySegment;
17+
import java.lang.invoke.MethodHandle;
18+
19+
public class DatasetUtilsImpl implements DatasetUtils {
20+
21+
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
22+
23+
private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeDatasetBuilder();
24+
25+
static DatasetUtils getInstance() {
26+
return INSTANCE;
27+
}
28+
29+
static Dataset fromMemorySegment(MemorySegment memorySegment, int size, int dimensions) {
30+
try {
31+
return (Dataset) createDataset$mh.invokeExact(memorySegment, size, dimensions);
32+
} catch (Throwable e) {
33+
if (e instanceof Error err) {
34+
throw err;
35+
} else if (e instanceof RuntimeException re) {
36+
throw re;
37+
} else {
38+
throw new RuntimeException(e);
39+
}
40+
}
41+
}
42+
43+
private DatasetUtilsImpl() {}
44+
45+
@Override
46+
public Dataset fromInput(MemorySegmentAccessInput input, int numVectors, int dims) throws IOException {
47+
if (numVectors < 0 || dims < 0) {
48+
throwIllegalArgumentException(numVectors, dims);
49+
}
50+
MemorySegment ms = input.segmentSliceOrNull(0L, input.length());
51+
assert ms != null; // TODO: this can be null if larger than 16GB or ...
52+
if (((long) numVectors * dims * Float.BYTES) < ms.byteSize()) {
53+
throwIllegalArgumentException(ms, numVectors, dims);
54+
}
55+
return fromMemorySegment(ms, numVectors, dims);
56+
}
57+
58+
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
59+
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + "dimensions";
60+
throw new IllegalArgumentException(s);
61+
}
62+
63+
static void throwIllegalArgumentException(int numVectors, int dims) {
64+
String s;
65+
if (numVectors < 0) {
66+
s = "negative number of vectors:" + numVectors;
67+
} else {
68+
s = "negative vector dims:" + dims;
69+
}
70+
throw new IllegalArgumentException(s);
71+
}
72+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import org.apache.lucene.store.Directory;
11+
import org.apache.lucene.store.IOContext;
12+
import org.apache.lucene.store.MMapDirectory;
13+
import org.apache.lucene.store.MemorySegmentAccessInput;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.junit.Before;
16+
17+
import java.lang.foreign.MemorySegment;
18+
import java.lang.foreign.ValueLayout;
19+
import java.nio.ByteOrder;
20+
21+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED;
22+
23+
public class DatasetUtilsTests extends ESTestCase {
24+
25+
@Before
26+
public void setup() { // TODO: abstract out setup in to common GPUTestcase
27+
assumeTrue("cuvs runtime only supported on 22 or greater, your JDK is " + Runtime.version(), Runtime.version().feature() >= 22);
28+
try (var resources = GPUVectorsFormat.cuVSResourcesOrNull()) {
29+
assumeTrue("cuvs not supported", resources != null);
30+
}
31+
}
32+
33+
static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
34+
35+
final DatasetUtils datasetUtils = DatasetUtils.getInstance();
36+
37+
public void testBasic() throws Exception {
38+
try (Directory dir = new MMapDirectory(createTempDir("testBasic"))) {
39+
int numVecs = randomIntBetween(1, 100);
40+
int dims = randomIntBetween(128, 2049);
41+
42+
try (var out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
43+
var ba = new byte[dims * Float.BYTES];
44+
var seg = MemorySegment.ofArray(ba);
45+
for (int v = 0; v < numVecs; v++) {
46+
var src = MemorySegment.ofArray(randomVector(dims));
47+
MemorySegment.copy(src, JAVA_FLOAT_UNALIGNED, 0L, seg, JAVA_FLOAT_LE, 0L, numVecs);
48+
out.writeBytes(ba, 0, ba.length);
49+
}
50+
}
51+
try (
52+
var in = dir.openInput("vector.data", IOContext.DEFAULT);
53+
var dataset = datasetUtils.fromInput((MemorySegmentAccessInput) in, numVecs, dims)
54+
) {
55+
assertEquals(numVecs, dataset.size());
56+
assertEquals(dims, dataset.dimensions());
57+
}
58+
}
59+
}
60+
61+
static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
62+
63+
public void testIllegal() {
64+
MemorySegmentAccessInput in = null; // TODO: make this non-null
65+
expectThrows(IAE, () -> datasetUtils.fromInput(in, -1, 1));
66+
expectThrows(IAE, () -> datasetUtils.fromInput(in, 1, -1));
67+
}
68+
69+
float[] randomVector(int dims) {
70+
float[] fa = new float[dims];
71+
for (int i = 0; i < dims; ++i) {
72+
fa[i] = random().nextFloat();
73+
}
74+
return fa;
75+
}
76+
}

0 commit comments

Comments
 (0)