diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java index b57d933d77c4b..64e2f920d2c0c 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java @@ -7,9 +7,9 @@ package org.elasticsearch.xpack.gpu.codec; +import com.nvidia.cuvs.CagraIndexParams; import com.nvidia.cuvs.CuVSMatrix; import com.nvidia.cuvs.CuVSResources; -import com.nvidia.cuvs.GPUInfoProvider; import com.nvidia.cuvs.spi.CuVSProvider; import org.elasticsearch.core.Strings; @@ -47,10 +47,8 @@ public interface CuVSResourceManager { * effect on GPU memory and compute usage to determine whether to give out * another resource or wait for a resources to be returned before giving out another. */ - // numVectors and dims are currently unused, but could be used along with GPU metadata, - // memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims, - // to give out a resources or not. - ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException; + ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) + throws InterruptedException; /** Marks the resources as finished with regard to compute. 
*/ void finishedComputation(ManagedCuVSResources resources); @@ -80,31 +78,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager { static class Holder { static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager( MAX_RESOURCES, - CuVSProvider.provider().gpuInfoProvider() + new RealGPUMemoryService(CuVSProvider.provider().gpuInfoProvider()) ); } private final ManagedCuVSResources[] pool; private final int capacity; - private final GPUInfoProvider gpuInfoProvider; + private final GPUMemoryService gpuMemoryService; private int createdCount; ReentrantLock lock = new ReentrantLock(); Condition enoughResourcesCondition = lock.newCondition(); - public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) { + PoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) { if (capacity < 1 || capacity > MAX_RESOURCES) { throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES); } this.capacity = capacity; - this.gpuInfoProvider = gpuInfoProvider; + this.gpuMemoryService = gpuMemoryService; this.pool = new ManagedCuVSResources[MAX_RESOURCES]; } private ManagedCuVSResources getResourceFromPool() { for (int i = 0; i < createdCount; ++i) { var res = pool[i]; - if (res.locked == false) { + if (res.isLocked() == false) { return res; } } @@ -120,7 +118,7 @@ private int numLockedResources() { int lockedResources = 0; for (int i = 0; i < createdCount; ++i) { var res = pool[i]; - if (res.locked) { + if (res.isLocked()) { lockedResources++; } } @@ -128,35 +126,37 @@ private int numLockedResources() { } @Override - public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException { + public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) + throws InterruptedException { try { var started = System.nanoTime(); lock.lock(); boolean allConditionsMet = false; 
ManagedCuVSResources res = null; + + long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, cagraIndexParams); + logger.debug( + "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]", + numVectors, + dims, + dataType.name(), + requiredMemoryInBytes + ); + while (allConditionsMet == false) { res = getResourceFromPool(); final boolean enoughMemory; if (res != null) { - long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType); - logger.debug( - "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]", - numVectors, - dims, - dataType.name(), - requiredMemoryInBytes - ); - // Check immutable constraints - long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes(); - if (requiredMemoryInBytes > totalDeviceMemoryInBytes) { + long totalMemoryInBytes = gpuMemoryService.totalMemoryInBytes(res); + if (requiredMemoryInBytes > totalMemoryInBytes) { String message = Strings.format( "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]", numVectors, dims, - totalDeviceMemoryInBytes + totalMemoryInBytes ); logger.error(message); throw new IllegalArgumentException(message); @@ -169,9 +169,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp } // Check resources availability - long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes(); - enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes; - logger.debug("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory); + long availableMemoryInBytes = gpuMemoryService.availableMemoryInBytes(res); + enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes; + logger.debug("Free device memory [{} B], enoughMemory[{}]", availableMemoryInBytes, enoughMemory); } else { logger.debug("No resources available in pool"); enoughMemory = false; @@ -184,19 +184,33 @@ public 
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp } - var elapsed = started - System.nanoTime(); + var elapsed = System.nanoTime() - started; logger.debug("Resource acquired in [{}ms]", elapsed / 1_000_000.0); - res.locked = true; + gpuMemoryService.reserveMemory(requiredMemoryInBytes); + res.lock(() -> gpuMemoryService.releaseMemory(requiredMemoryInBytes)); return res; } finally { lock.unlock(); } } - private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) { + private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) { int elementTypeBytes = switch (dataType) { case FLOAT -> Float.BYTES; case INT, UINT -> Integer.BYTES; case BYTE -> Byte.BYTES; }; + + if (cagraIndexParams.getCagraGraphBuildAlgo() == CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ + && cagraIndexParams.getCuVSIvfPqParams() != null + && cagraIndexParams.getCuVSIvfPqParams().getIndexParams() != null + && cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim() != 0) { + // See https://docs.rapids.ai/api/cuvs/nightly/neighbors/ivfpq/#index-device-memory + var pqDim = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim(); + var pqBits = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqBits(); + var numClusters = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getnLists(); + var approximatedIvfBytes = numVectors * (pqDim * (pqBits / 8.0) + elementTypeBytes) + (long) numClusters * Integer.BYTES; + return (long) (GPU_COMPUTATION_MEMORY_FACTOR * approximatedIvfBytes); + } + return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes); } @@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) { logger.debug("Releasing resources to pool"); try { lock.lock(); - assert resources.locked; - resources.locked = false; + assert resources.isLocked(); + resources.unlock(); enoughResourcesCondition.signalAll(); } finally { lock.unlock(); @@ -238,8 +252,9 @@ public
void shutdown() { /** A managed resource. Cannot be closed. */ final class ManagedCuVSResources implements CuVSResources { - final CuVSResources delegate; - boolean locked = false; + private final CuVSResources delegate; + private static final Runnable NOT_LOCKED = () -> {}; + private Runnable unlockAction = NOT_LOCKED; ManagedCuVSResources(CuVSResources resources) { this.delegate = resources; @@ -269,5 +284,18 @@ public Path tempDirectory() { public String toString() { return "ManagedCuVSResources[delegate=" + delegate + "]"; } + + void lock(Runnable unlockAction) { + this.unlockAction = unlockAction; + } + + void unlock() { + unlockAction.run(); + unlockAction = NOT_LOCKED; + } + + boolean isLocked() { + return unlockAction != NOT_LOCKED; + } } } diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java index 2f4ee6180af1d..20d7612533cc4 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java @@ -178,6 +178,8 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO var started = System.nanoTime(); var fieldInfo = field.fieldInfo; + CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction()); + var numVectors = field.flatFieldVectorsWriter.getVectors().size(); if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) { if (logger.isDebugEnabled()) { @@ -193,7 +195,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO try ( var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), CuVSMatrix.DataType.FLOAT) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), 
CuVSMatrix.DataType.FLOAT, cagraIndexParams) ) ) { var builder = CuVSMatrix.deviceBuilder( @@ -206,7 +208,7 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO builder.addVector(vector); } try (var dataset = builder.build()) { - flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap); + flushFieldWithGpuGraph(resourcesHolder, fieldInfo, dataset, sortMap, cagraIndexParams); } } } @@ -224,13 +226,18 @@ private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter } } - private void flushFieldWithGpuGraph(ResourcesHolder resourcesHolder, FieldInfo fieldInfo, CuVSMatrix dataset, Sorter.DocMap sortMap) - throws IOException { + private void flushFieldWithGpuGraph( + ResourcesHolder resourcesHolder, + FieldInfo fieldInfo, + CuVSMatrix dataset, + Sorter.DocMap sortMap, + CagraIndexParams cagraIndexParams + ) throws IOException { if (sortMap == null) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } else { // TODO: use sortMap - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } @@ -262,14 +269,19 @@ public long ramBytesUsed() { return total; } - private void generateGpuGraphAndWriteMeta(ResourcesHolder resourcesHolder, FieldInfo fieldInfo, CuVSMatrix dataset) throws IOException { + private void generateGpuGraphAndWriteMeta( + ResourcesHolder resourcesHolder, + FieldInfo fieldInfo, + CuVSMatrix dataset, + CagraIndexParams cagraIndexParams + ) throws IOException { try { assert dataset.size() >= MIN_NUM_VECTORS_FOR_GPU_BUILD; long vectorIndexOffset = vectorIndex.getFilePointer(); int[][] graphLevelNodeOffsets = new int[1][]; final HnswGraph graph; - try (var index = buildGPUIndex(resourcesHolder.resources(), fieldInfo.getVectorSimilarityFunction(), dataset)) { + try (var index = 
buildGPUIndex(resourcesHolder.resources(), cagraIndexParams, dataset)) { assert index != null : "GPU index should be built for field: " + fieldInfo.name; var deviceGraph = index.getGraph(); var graphSize = deviceGraph.size() * deviceGraph.columns() * Integer.BYTES; @@ -309,9 +321,20 @@ private void generateMockGraphAndWriteMeta(FieldInfo fieldInfo, int datasetSize) private CagraIndex buildGPUIndex( CuVSResourceManager.ManagedCuVSResources cuVSResources, - VectorSimilarityFunction similarityFunction, + CagraIndexParams cagraIndexParams, CuVSMatrix dataset ) throws Throwable { + long startTime = System.nanoTime(); + var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(cagraIndexParams); + var index = indexBuilder.build(); + cuVSResourceManager.finishedComputation(cuVSResources); + if (logger.isDebugEnabled()) { + logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size()); + } + return index; + } + + private CagraIndexParams createCagraIndexParams(VectorSimilarityFunction similarityFunction) { CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) { case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded; case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded; @@ -328,22 +351,13 @@ private CagraIndex buildGPUIndex( }; // TODO: expose cagra index params for algorithm, NNDescentNumIterations - CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use? + return new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use? 
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .withGraphDegree(M) .withIntermediateGraphDegree(beamWidth) .withNNDescentNumIterations(5) .withMetric(distanceType) .build(); - - long startTime = System.nanoTime(); - var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params); - var index = indexBuilder.build(); - cuVSResourceManager.finishedComputation(cuVSResources); - if (logger.isDebugEnabled()) { - logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size()); - } - return index; } private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) throws IOException { @@ -500,6 +514,9 @@ private void mergeByteVectorField( var vectorValues = randomScorerSupplier == null ? null : VectorsFormatReflectionUtils.getByteScoringSupplierVectorOrNull(randomScorerSupplier); + + CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction()); + if (vectorValues != null) { IndexInput slice = vectorValues.getSlice(); var input = FilterIndexInput.unwrapOnlyTest(slice); @@ -533,10 +550,10 @@ private void mergeByteVectorField( var dataset = DatasetUtilsImpl.fromMemorySegment(packedSegment, numVectors, packedRowSize, dataType); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } } else { @@ -557,10 +574,10 @@ private void mergeByteVectorField( var dataset = builder.build(); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType) + 
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } } else { @@ -578,10 +595,10 @@ private void mergeByteVectorField( var dataset = builder.build(); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } } @@ -595,6 +612,8 @@ private void mergeFloatVectorField( var vectorValues = randomScorerSupplier == null ? null : VectorsFormatReflectionUtils.getFloatScoringSupplierVectorOrNull(randomScorerSupplier); + CagraIndexParams cagraIndexParams = createCagraIndexParams(fieldInfo.getVectorSimilarityFunction()); + if (vectorValues != null) { IndexInput slice = vectorValues.getSlice(); var input = FilterIndexInput.unwrapOnlyTest(slice); @@ -605,10 +624,10 @@ private void mergeFloatVectorField( .fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } else { logger.info( @@ -628,10 +647,10 @@ private void mergeFloatVectorField( var dataset = builder.build(); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, 
fieldInfo.getVectorDimension(), dataType) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } } else { @@ -650,10 +669,10 @@ private void mergeFloatVectorField( var dataset = builder.build(); var resourcesHolder = new ResourcesHolder( cuVSResourceManager, - cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType) + cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType, cagraIndexParams) ) ) { - generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset); + generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset, cagraIndexParams); } } } diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUMemoryService.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUMemoryService.java new file mode 100644 index 0000000000000..f33325a5a7ff1 --- /dev/null +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUMemoryService.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.gpu.codec; + +import com.nvidia.cuvs.CuVSResources; + +/** + * Abstracts GPU memory tracking (total vs available) + */ +interface GPUMemoryService { + + long totalMemoryInBytes(CuVSResources res); + + long availableMemoryInBytes(CuVSResources res); + + void reserveMemory(long memoryInBytes); + + void releaseMemory(long memoryInBytes); +} diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/RealGPUMemoryService.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/RealGPUMemoryService.java new file mode 100644 index 0000000000000..a87294900da9f --- /dev/null +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/RealGPUMemoryService.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.gpu.codec; + +import com.nvidia.cuvs.CuVSResources; +import com.nvidia.cuvs.GPUInfoProvider; + +/** + * A {@link GPUMemoryService} that tracks how much memory is currently used/available on a GPU by using the GPU free/total memory APIs + * (via a {@link GPUInfoProvider}) + */ +class RealGPUMemoryService implements GPUMemoryService { + private final GPUInfoProvider gpuInfoProvider; + + RealGPUMemoryService(GPUInfoProvider gpuInfoProvider) { + this.gpuInfoProvider = gpuInfoProvider; + } + + @Override + public long totalMemoryInBytes(CuVSResources res) { + return gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes(); + } + + @Override + public long availableMemoryInBytes(CuVSResources res) { + return gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes(); + } + + @Override + public void reserveMemory(long memoryInBytes) { + // No-op + } + + @Override + public void releaseMemory(long memoryInBytes) { + // No-op + } +} diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/TrackingGPUMemoryService.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/TrackingGPUMemoryService.java new file mode 100644 index 0000000000000..b9f5c5f7ff82e --- /dev/null +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/TrackingGPUMemoryService.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.gpu.codec; + +import com.nvidia.cuvs.CuVSResources; + +/** + * A {@link GPUMemoryService} that tracks manually how much memory is currently estimated to be used/available on a GPU. 
+ * This implementation is useful when we are not able to use a "Real memory" measurement; for example, if we are using pooled RMM memory, + * the pool will permanently occupy most of the GPU RAM, allocations will happen inside the pool, and the "Real memory" measurement API + * will always report a (tiny) fixed amount of free memory (whatever is not in the pool). + */ +class TrackingGPUMemoryService implements GPUMemoryService { + + private final long totalMemoryInBytes; + private long availableMemoryInBytes; + + TrackingGPUMemoryService(long totalMemoryInBytes) { + this.totalMemoryInBytes = totalMemoryInBytes; + this.availableMemoryInBytes = totalMemoryInBytes; + } + + @Override + public long totalMemoryInBytes(CuVSResources res) { + return totalMemoryInBytes; + } + + @Override + public long availableMemoryInBytes(CuVSResources res) { + return availableMemoryInBytes; + } + + @Override + public void reserveMemory(long memoryInBytes) { + availableMemoryInBytes -= memoryInBytes; + } + + @Override + public void releaseMemory(long memoryInBytes) { + availableMemoryInBytes += memoryInBytes; + } +} diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java index b466f37cbe9c9..e704dbb1f02a6 100644 --- a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java @@ -7,26 +7,24 @@ package org.elasticsearch.xpack.gpu.codec; +import com.nvidia.cuvs.CagraIndexParams; +import com.nvidia.cuvs.CuVSIvfPqIndexParams; +import com.nvidia.cuvs.CuVSIvfPqParams; import com.nvidia.cuvs.CuVSMatrix; import com.nvidia.cuvs.CuVSResources; -import com.nvidia.cuvs.CuVSResourcesInfo; -import com.nvidia.cuvs.GPUInfo; -import com.nvidia.cuvs.GPUInfoProvider; import org.elasticsearch.logging.LogManager; import 
org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; import java.nio.file.Path; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.LongSupplier; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; public class CuVSResourceManagerTests extends ESTestCase { @@ -35,16 +33,16 @@ public class CuVSResourceManagerTests extends ESTestCase { public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024; - public void testBasic() throws InterruptedException { + private static void testBasic(CagraIndexParams params) throws InterruptedException { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); - var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); assertThat(res1.toString(), containsString("id=0")); assertThat(res2.toString(), containsString("id=1")); mgr.release(res1); mgr.release(res2); - res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); - res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); + res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); assertThat(res1.toString(), containsString("id=0")); assertThat(res2.toString(), containsString("id=1")); mgr.release(res1); @@ -52,15 +50,44 @@ public void testBasic() throws InterruptedException { mgr.shutdown(); } - public void testBlocking() throws Exception { + public void testBasicWithNNDescent() throws InterruptedException { + testBasic(createNnDescentParams()); + } + + public void testBasicWithIvfPq() throws 
InterruptedException { + testBasic(createIvfPqParams()); + } + + public void testMultipleAcquireRelease() throws InterruptedException { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); - var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); + var res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams()); + assertThat(res1.toString(), containsString("id=0")); + assertThat(res2.toString(), containsString("id=1")); + assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2)); + mgr.release(res1); + mgr.release(res2); + assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES)); + res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); + res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams()); + assertThat(res1.toString(), containsString("id=0")); + assertThat(res2.toString(), containsString("id=1")); + assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2)); + mgr.release(res1); + mgr.release(res2); + assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES)); + mgr.shutdown(); + } + + private static void testBlocking(CagraIndexParams params) throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); AtomicReference holder = new AtomicReference<>(); Thread t = new Thread(() -> { try { - var res3 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res3 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); holder.set(res3); } catch (InterruptedException e) { throw new AssertionError(e); @@ -75,14 +102,21 @@ public void testBlocking() throws Exception { mgr.shutdown(); } - public void 
testBlockingOnInsufficientMemory() throws Exception { - var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT); + public void testBlockingWithNNDescent() throws Exception { + testBlocking(createNnDescentParams()); + } + + public void testBlockingWithIvfPq() throws Exception { + testBlocking(createIvfPqParams()); + } + + private static void testBlockingOnInsufficientMemory(CagraIndexParams params, CuVSResourceManager mgr) throws Exception { + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, params); AtomicReference holder = new AtomicReference<>(); Thread t = new Thread(() -> { try { - var res2 = mgr.acquire((16 * 1024) + 1, 1024, CuVSMatrix.DataType.FLOAT); + var res2 = mgr.acquire((16 * 1024) + 1, 1024, CuVSMatrix.DataType.FLOAT, params); holder.set(res2); } catch (InterruptedException e) { throw new AssertionError(e); @@ -97,14 +131,23 @@ public void testBlockingOnInsufficientMemory() throws Exception { mgr.shutdown(); } - public void testNotBlockingOnSufficientMemory() throws Exception { + public void testBlockingOnInsufficientMemoryNnDescent() throws Exception { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT); + testBlockingOnInsufficientMemory(createNnDescentParams(), mgr); + } + + public void testBlockingOnInsufficientMemoryIvfPq() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2, 32L * 1024 * 1024); + testBlockingOnInsufficientMemory(createIvfPqParams(), mgr); + } + + private static void testNotBlockingOnSufficientMemory(CagraIndexParams params, CuVSResourceManager mgr) throws Exception { + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, params); AtomicReference holder = new AtomicReference<>(); Thread t = new Thread(() -> { try { - var res2 = mgr.acquire((16 * 1024) - 1, 1024, CuVSMatrix.DataType.FLOAT); + var res2 = mgr.acquire((16 * 1024) - 1000, 1024, 
CuVSMatrix.DataType.FLOAT, params); holder.set(res2); } catch (InterruptedException e) { throw new AssertionError(e); @@ -117,9 +160,19 @@ public void testNotBlockingOnSufficientMemory() throws Exception { mgr.shutdown(); } + public void testNotBlockingOnSufficientMemoryNnDescent() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2); + testNotBlockingOnSufficientMemory(createNnDescentParams(), mgr); + } + + public void testNotBlockingOnSufficientMemoryIvfPq() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2, 32L * 1024 * 1024); + testNotBlockingOnSufficientMemory(createIvfPqParams(), mgr); + } + public void testManagedResIsNotClosable() throws Exception { var mgr = new MockPoolingCuVSResourceManager(1); - var res = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); assertThrows(UnsupportedOperationException.class, res::close); mgr.release(res); mgr.shutdown(); @@ -127,54 +180,55 @@ public void testManagedResIsNotClosable() throws Exception { public void testDoubleRelease() throws InterruptedException { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); - var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); mgr.release(res1); mgr.release(res2); assertThrows(AssertionError.class, () -> mgr.release(randomFrom(res1, res2))); mgr.shutdown(); } + private static CagraIndexParams createNnDescentParams() { + return new CagraIndexParams.Builder().withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .withNNDescentNumIterations(5) + .build(); + } + + private static CagraIndexParams createIvfPqParams() { + return new CagraIndexParams.Builder().withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ) + 
.withCuVSIvfPqParams( + new CuVSIvfPqParams.Builder().withCuVSIvfPqIndexParams( + new CuVSIvfPqIndexParams.Builder().withPqBits(4).withPqDim(1024).build() + ).build() + ) + .build(); + } + static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager { private final AtomicInteger idGenerator = new AtomicInteger(); - private final List allocations; + private final GPUMemoryService gpuMemoryService; MockPoolingCuVSResourceManager(int capacity) { - this(capacity, new ArrayList<>()); + this(capacity, TOTAL_DEVICE_MEMORY_IN_BYTES); } - private MockPoolingCuVSResourceManager(int capacity, List allocationList) { - super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList))); - this.allocations = allocationList; + MockPoolingCuVSResourceManager(int capacity, long totalMemoryInBytes) { + this(capacity, new TrackingGPUMemoryService(totalMemoryInBytes)); } - private static long freeMemoryFunction(List allocations) { - return TOTAL_DEVICE_MEMORY_IN_BYTES - allocations.stream().mapToLong(x -> x).sum(); + private MockPoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) { + super(capacity, gpuMemoryService); + this.gpuMemoryService = gpuMemoryService; } - @Override - protected CuVSResources createNew() { - return new MockCuVSResources(idGenerator.getAndIncrement()); + long availableMemory() { + return gpuMemoryService.availableMemoryInBytes(null); } @Override - public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException { - var res = super.acquire(numVectors, dims, dataType); - long memory = (long) (numVectors * dims * Float.BYTES - * CuVSResourceManager.PoolingCuVSResourceManager.GPU_COMPUTATION_MEMORY_FACTOR); - allocations.add(memory); - log.info("Added [{}]", memory); - return res; - } - - @Override - public void release(ManagedCuVSResources resources) { - if (allocations.isEmpty() == false) { - var x = allocations.removeLast(); - 
log.info("Removed [{}]", x); - } - super.release(resources); + protected CuVSResources createNew() { + return new MockCuVSResources(idGenerator.getAndIncrement()); } } @@ -209,27 +263,4 @@ public String toString() { return "MockCuVSResources[id=" + id + "]"; } } - - private static class MockGPUInfoProvider implements GPUInfoProvider { - private final LongSupplier freeMemorySupplier; - - MockGPUInfoProvider(LongSupplier freeMemorySupplier) { - this.freeMemorySupplier = freeMemorySupplier; - } - - @Override - public List availableGPUs() { - throw new UnsupportedOperationException(); - } - - @Override - public List compatibleGPUs() { - throw new UnsupportedOperationException(); - } - - @Override - public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) { - return new CuVSResourcesInfo(freeMemorySupplier.getAsLong(), TOTAL_DEVICE_MEMORY_IN_BYTES); - } - } }