Skip to content

Commit 210cc1f

Browse files
committed
Include Cagra Params (algo + parameters) in occupancy estimation
1 parent 5050166 commit 210cc1f

File tree

3 files changed

+143
-109
lines changed

3 files changed

+143
-109
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,8 @@ public interface CuVSResourceManager {
4747
* effect on GPU memory and compute usage to determine whether to give out
4848
* another resource or wait for a resource to be returned before giving out another.
4949
*/
50-
ManagedCuVSResources acquire(
51-
int numVectors,
52-
int dims,
53-
CuVSMatrix.DataType dataType,
54-
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
55-
) throws InterruptedException;
50+
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
51+
throws InterruptedException;
5652

5753
/** Marks the resources as finished with regard to compute. */
5854
void finishedComputation(ManagedCuVSResources resources);
@@ -130,20 +126,16 @@ private int numLockedResources() {
130126
}
131127

132128
@Override
133-
public ManagedCuVSResources acquire(
134-
int numVectors,
135-
int dims,
136-
CuVSMatrix.DataType dataType,
137-
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
138-
) throws InterruptedException {
129+
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
130+
throws InterruptedException {
139131
try {
140132
var started = System.nanoTime();
141133
lock.lock();
142134

143135
boolean allConditionsMet = false;
144136
ManagedCuVSResources res = null;
145137

146-
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, graphBuildAlgo);
138+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, cagraIndexParams);
147139
logger.debug(
148140
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
149141
numVectors,
@@ -200,17 +192,25 @@ public ManagedCuVSResources acquire(
200192
}
201193
}
202194

203-
private long estimateRequiredMemory(
204-
int numVectors,
205-
int dims,
206-
CuVSMatrix.DataType dataType,
207-
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
208-
) {
195+
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) {
209196
int elementTypeBytes = switch (dataType) {
210197
case FLOAT -> Float.BYTES;
211198
case INT, UINT -> Integer.BYTES;
212199
case BYTE -> Byte.BYTES;
213200
};
201+
202+
if (cagraIndexParams.getCagraGraphBuildAlgo() == CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ
203+
&& cagraIndexParams.getCuVSIvfPqParams() != null
204+
&& cagraIndexParams.getCuVSIvfPqParams().getIndexParams() != null
205+
&& cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim() != 0) {
206+
// See https://docs.rapids.ai/api/cuvs/nightly/neighbors/ivfpq/#index-device-memory
207+
var pqDim = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim();
208+
var pqBits = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqBits();
209+
var numClusters = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getnLists();
210+
var approximatedIvfBytes = numVectors * (pqDim * (pqBits / 8.0) + elementTypeBytes) + (long) numClusters * Integer.BYTES;
211+
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * approximatedIvfBytes);
212+
}
213+
214214
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes);
215215
}
216216

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java

Lines changed: 50 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
178178
var started = System.nanoTime();
179179
var fieldInfo = field.fieldInfo;
180180

181+
CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction());
182+
181183
var numVectors = field.flatFieldVectorsWriter.getVectors().size();
182184
if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
183185
if (logger.isDebugEnabled()) {
@@ -193,12 +195,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
193195
try (
194196
var resourcesHolder = new ResourcesHolder(
195197
cuVSResourceManager,
196-
cuVSResourceManager.acquire(
197-
numVectors,
198-
fieldInfo.getVectorDimension(),
199-
CuVSMatrix.DataType.FLOAT,
200-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
201-
)
198+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), CuVSMatrix.DataType.FLOAT, cagraIndexParams)
202199
)
203200
) {
204201
var builder = CuVSMatrix.deviceBuilder(
@@ -211,7 +208,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
211208
builder.addVector(vector);
212209
}
213210
try (var dataset = builder.build()) {
214-
flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap);
211+
flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap, cagraIndexParams);
215212
}
216213
}
217214
}
@@ -229,13 +226,18 @@ private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter
229226
}
230227
}
231228

232-
private void flushFieldWithGpuGraph(ResourcesHolder resourcesHolder, FieldInfo fieldInfo, CuVSMatrix dataset, Sorter.DocMap sortMap)
233-
throws IOException {
229+
private void flushFieldWithGpuGraph(
230+
ResourcesHolder resourcesHolder,
231+
FieldInfo fieldInfo,
232+
CuVSMatrix dataset,
233+
Sorter.DocMap sortMap,
234+
CagraIndexParams cagraIndexParams
235+
) throws IOException {
234236
if (sortMap == null) {
235-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
237+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
236238
} else {
237239
// TODO: use sortMap
238-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
240+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
239241
}
240242
}
241243

@@ -267,14 +269,19 @@ public long ramBytesUsed() {
267269
return total;
268270
}
269271

270-
private void generateGpuGraphAndWriteMeta(ResourcesHolder resourcesHolder, FieldInfo fieldInfo, CuVSMatrix dataset) throws IOException {
272+
private void generateGpuGraphAndWriteMeta(
273+
ResourcesHolder resourcesHolder,
274+
FieldInfo fieldInfo,
275+
CuVSMatrix dataset,
276+
CagraIndexParams cagraIndexParams
277+
) throws IOException {
271278
try {
272279
assert dataset.size() >= MIN_NUM_VECTORS_FOR_GPU_BUILD;
273280

274281
long vectorIndexOffset = vectorIndex.getFilePointer();
275282
int[][] graphLevelNodeOffsets = new int[1][];
276283
final HnswGraph graph;
277-
try (var index = buildGPUIndex(resourcesHolder.resources(), fieldInfo.getVectorSimilarityFunction(), dataset)) {
284+
try (var index = buildGPUIndex(resourcesHolder.resources(), cagraIndexParams, dataset)) {
278285
assert index != null : "GPU index should be built for field: " + fieldInfo.name;
279286
var deviceGraph = index.getGraph();
280287
var graphSize = deviceGraph.size() * deviceGraph.columns() * Integer.BYTES;
@@ -314,9 +321,20 @@ private void generateMockGraphAndWriteMeta(FieldInfo fieldInfo, int datasetSize)
314321

315322
private CagraIndex buildGPUIndex(
316323
CuVSResourceManager.ManagedCuVSResources cuVSResources,
317-
VectorSimilarityFunction similarityFunction,
324+
CagraIndexParams cagraIndexParams,
318325
CuVSMatrix dataset
319326
) throws Throwable {
327+
long startTime = System.nanoTime();
328+
var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(cagraIndexParams);
329+
var index = indexBuilder.build();
330+
cuVSResourceManager.finishedComputation(cuVSResources);
331+
if (logger.isDebugEnabled()) {
332+
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size());
333+
}
334+
return index;
335+
}
336+
337+
private CagraIndexParams createCagraIndexParams(VectorSimilarityFunction similarityFunction) {
320338
CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
321339
case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
322340
case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
@@ -333,22 +351,13 @@ private CagraIndex buildGPUIndex(
333351
};
334352

335353
// TODO: expose cagra index params for algorithm, NNDescentNumIterations
336-
CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
354+
return new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
337355
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
338356
.withGraphDegree(M)
339357
.withIntermediateGraphDegree(beamWidth)
340358
.withNNDescentNumIterations(5)
341359
.withMetric(distanceType)
342360
.build();
343-
344-
long startTime = System.nanoTime();
345-
var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params);
346-
var index = indexBuilder.build();
347-
cuVSResourceManager.finishedComputation(cuVSResources);
348-
if (logger.isDebugEnabled()) {
349-
logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size());
350-
}
351-
return index;
352361
}
353362

354363
private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) throws IOException {
@@ -505,6 +514,9 @@ private void mergeByteVectorField(
505514
var vectorValues = randomScorerSupplier == null
506515
? null
507516
: VectorsFormatReflectionUtils.getByteScoringSupplierVectorOrNull(randomScorerSupplier);
517+
518+
CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction());
519+
508520
if (vectorValues != null) {
509521
IndexInput slice = vectorValues.getSlice();
510522
var input = FilterIndexInput.unwrapOnlyTest(slice);
@@ -538,15 +550,10 @@ private void mergeByteVectorField(
538550
var dataset = DatasetUtilsImpl.fromMemorySegment(packedSegment, numVectors, packedRowSize, dataType);
539551
var resourcesHolder = new ResourcesHolder(
540552
cuVSResourceManager,
541-
cuVSResourceManager.acquire(
542-
numVectors,
543-
fieldInfo.getVectorDimension(),
544-
dataType,
545-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
546-
)
553+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
547554
)
548555
) {
549-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
556+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
550557
}
551558
}
552559
} else {
@@ -567,15 +574,10 @@ private void mergeByteVectorField(
567574
var dataset = builder.build();
568575
var resourcesHolder = new ResourcesHolder(
569576
cuVSResourceManager,
570-
cuVSResourceManager.acquire(
571-
numVectors,
572-
fieldInfo.getVectorDimension(),
573-
dataType,
574-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
575-
)
577+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
576578
)
577579
) {
578-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
580+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
579581
}
580582
}
581583
} else {
@@ -593,15 +595,10 @@ private void mergeByteVectorField(
593595
var dataset = builder.build();
594596
var resourcesHolder = new ResourcesHolder(
595597
cuVSResourceManager,
596-
cuVSResourceManager.acquire(
597-
numVectors,
598-
fieldInfo.getVectorDimension(),
599-
dataType,
600-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
601-
)
598+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
602599
)
603600
) {
604-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
601+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
605602
}
606603
}
607604
}
@@ -615,6 +612,8 @@ private void mergeFloatVectorField(
615612
var vectorValues = randomScorerSupplier == null
616613
? null
617614
: VectorsFormatReflectionUtils.getFloatScoringSupplierVectorOrNull(randomScorerSupplier);
615+
CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction());
616+
618617
if (vectorValues != null) {
619618
IndexInput slice = vectorValues.getSlice();
620619
var input = FilterIndexInput.unwrapOnlyTest(slice);
@@ -625,15 +624,10 @@ private void mergeFloatVectorField(
625624
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
626625
var resourcesHolder = new ResourcesHolder(
627626
cuVSResourceManager,
628-
cuVSResourceManager.acquire(
629-
numVectors,
630-
fieldInfo.getVectorDimension(),
631-
dataType,
632-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
633-
)
627+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
634628
)
635629
) {
636-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
630+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
637631
}
638632
} else {
639633
logger.info(
@@ -653,15 +647,10 @@ private void mergeFloatVectorField(
653647
var dataset = builder.build();
654648
var resourcesHolder = new ResourcesHolder(
655649
cuVSResourceManager,
656-
cuVSResourceManager.acquire(
657-
numVectors,
658-
fieldInfo.getVectorDimension(),
659-
dataType,
660-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
661-
)
650+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
662651
)
663652
) {
664-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
653+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
665654
}
666655
}
667656
} else {
@@ -680,15 +669,10 @@ private void mergeFloatVectorField(
680669
var dataset = builder.build();
681670
var resourcesHolder = new ResourcesHolder(
682671
cuVSResourceManager,
683-
cuVSResourceManager.acquire(
684-
numVectors,
685-
fieldInfo.getVectorDimension(),
686-
dataType,
687-
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
688-
)
672+
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams)
689673
)
690674
) {
691-
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
675+
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams);
692676
}
693677
}
694678
}

0 commit comments

Comments
 (0)