diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java index e7977f28c9c22..26e4e94ed57ea 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java @@ -7,14 +7,20 @@ package org.elasticsearch.xpack.gpu.codec; +import com.nvidia.cuvs.CuVSMatrix; import com.nvidia.cuvs.CuVSResources; +import com.nvidia.cuvs.GPUInfoProvider; +import com.nvidia.cuvs.spi.CuVSProvider; +import org.elasticsearch.core.Strings; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.xpack.gpu.GPUSupport; import java.nio.file.Path; import java.util.Objects; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; /** * A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU. @@ -44,7 +50,7 @@ public interface CuVSResourceManager { // numVectors and dims are currently unused, but could be used along with GPU metadata, // memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims, // to give out a resources or not. - ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException; + ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException; /** Marks the resources as finished with regard to compute. */ void finishedComputation(ManagedCuVSResources resources); @@ -65,35 +71,127 @@ static CuVSResourceManager pooling() { */ class PoolingCuVSResourceManager implements CuVSResourceManager { + static final Logger logger = LogManager.getLogger(CuVSResourceManager.class); + + /** A multiplier on input data to account for intermediate and output data size required while processing it */ + static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0; static final int MAX_RESOURCES = 2; - static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES); + static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager( + MAX_RESOURCES, + CuVSProvider.provider().gpuInfoProvider() + ); + + private final ManagedCuVSResources[] pool; + private final int capacity; + private final GPUInfoProvider gpuInfoProvider; + private int createdCount; - final BlockingQueue pool; - final int capacity; - int createdCount; + ReentrantLock lock = new ReentrantLock(); + Condition enoughResourcesCondition = lock.newCondition(); - public PoolingCuVSResourceManager(int capacity) { + public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) { if (capacity < 1 || capacity > MAX_RESOURCES) { throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES); } this.capacity = capacity; - this.pool = new ArrayBlockingQueue<>(capacity); + this.gpuInfoProvider = gpuInfoProvider; + this.pool = new ManagedCuVSResources[MAX_RESOURCES]; } - @Override - public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException { - ManagedCuVSResources res = pool.poll(); - if (res != null) { + private ManagedCuVSResources getResourceFromPool() { + for (int i = 0; i < createdCount; ++i) { + var res = pool[i]; + if (res.locked == false) { + return res; + } + } + if (createdCount < capacity) { + var res = new ManagedCuVSResources(Objects.requireNonNull(createNew())); + pool[createdCount++] = res; return res; } - synchronized (this) { - if (createdCount < capacity) { - createdCount++; - return new ManagedCuVSResources(Objects.requireNonNull(createNew())); + return null; + } + + private int numLockedResources() { + int lockedResources = 0; + for (int i = 0; i < createdCount; ++i) { + var res = pool[i]; + if (res.locked) { + lockedResources++; } } - // Otherwise, wait for one to be released - return pool.take(); + return lockedResources; + } + + @Override + public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException { + try { + lock.lock(); + + boolean allConditionsMet = false; + ManagedCuVSResources res = null; + while (allConditionsMet == false) { + res = getResourceFromPool(); + + final boolean enoughMemory; + if (res != null) { + long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType); + logger.info( + "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]", + numVectors, + dims, + dataType.name(), + requiredMemoryInBytes + ); + + // Check immutable constraints + long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes(); + if (requiredMemoryInBytes > totalDeviceMemoryInBytes) { + String message = Strings.format( + "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]", + numVectors, + dims, + totalDeviceMemoryInBytes + ); + logger.error(message); + throw new IllegalArgumentException(message); + } + + // If no resource in the pool is locked, short circuit to avoid livelock + if (numLockedResources() == 0) { + logger.info("No resources currently locked, proceeding"); + break; + } + + // Check resources availability + long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes(); + enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes; + logger.info("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes); + } else { + logger.info("No resources available in pool"); + enoughMemory = false; + } + // TODO: add enoughComputation / enoughComputationCondition here + allConditionsMet = enoughMemory; // && enoughComputation + if (allConditionsMet == false) { + enoughResourcesCondition.await(); + } + } + res.locked = true; + return res; + } finally { + lock.unlock(); + } + } + + private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) { + int elementTypeBytes = switch (dataType) { + case FLOAT -> Float.BYTES; + case INT, UINT -> Integer.BYTES; + case BYTE -> Byte.BYTES; + }; + return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes); } // visible for testing @@ -103,21 +201,31 @@ protected CuVSResources createNew() { @Override public void finishedComputation(ManagedCuVSResources resources) { + logger.info("Computation finished"); // currently does nothing, but could allow acquire to return possibly blocked resources + // enoughResourcesCondition.signalAll() } @Override public void release(ManagedCuVSResources resources) { - var added = pool.offer(Objects.requireNonNull(resources)); - assert added : "Failed to release resource back to pool"; + logger.info("Releasing resources to pool"); + try { + lock.lock(); + assert resources.locked; + resources.locked = false; + enoughResourcesCondition.signalAll(); + } finally { + lock.unlock(); + } } @Override public void shutdown() { - for (ManagedCuVSResources res : pool) { + for (int i = 0; i < createdCount; ++i) { + var res = pool[i]; + assert res != null; res.delegate.close(); } - pool.clear(); } } @@ -125,6 +233,7 @@ public void shutdown() { final class ManagedCuVSResources implements CuVSResources { final CuVSResources delegate; + boolean locked = false; ManagedCuVSResources(CuVSResources resources) { this.delegate = resources; @@ -135,6 +244,11 @@ public ScopedAccess access() { return delegate.access(); } + @Override + public int deviceId() { + return delegate.deviceId(); + } + @Override public void close() { throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients"); diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java index b8215e4fbc702..5a166fd2eeac0 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java @@ -250,7 +250,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrV mockGraph = writeGraph(size, graphLevelNodeOffsets); } else { var dataset = datasetOrVectors.getDataset(); - var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns()); + var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType()); try { try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) { assert index != null : "GPU index should be built for field: " + fieldInfo.name; diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java index a5bac96cc3b51..b466f37cbe9c9 100644 --- a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java @@ -7,29 +7,44 @@ package org.elasticsearch.xpack.gpu.codec; +import com.nvidia.cuvs.CuVSMatrix; import com.nvidia.cuvs.CuVSResources; +import com.nvidia.cuvs.CuVSResourcesInfo; +import com.nvidia.cuvs.GPUInfo; +import com.nvidia.cuvs.GPUInfoProvider; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.test.ESTestCase; import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; public class CuVSResourceManagerTests extends ESTestCase { + private static final Logger log = LogManager.getLogger(CuVSResourceManagerTests.class); + + public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024; + public void testBasic() throws InterruptedException { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0); - var res2 = mgr.acquire(0, 0); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); assertThat(res1.toString(), containsString("id=0")); assertThat(res2.toString(), containsString("id=1")); mgr.release(res1); mgr.release(res2); - res1 = mgr.acquire(0, 0); - res2 = mgr.acquire(0, 0); + res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); assertThat(res1.toString(), containsString("id=0")); assertThat(res2.toString(), containsString("id=1")); mgr.release(res1); @@ -39,13 +54,13 @@ public void testBasic() throws InterruptedException { public void testBlocking() throws Exception { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0); - var res2 = mgr.acquire(0, 0); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); AtomicReference holder = new AtomicReference<>(); Thread t = new Thread(() -> { try { - var res3 = mgr.acquire(0, 0); + var res3 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); holder.set(res3); } catch (InterruptedException e) { throw new AssertionError(e); @@ -60,18 +75,60 @@ public void testBlocking() throws Exception { mgr.shutdown(); } + public void testBlockingOnInsufficientMemory() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT); + + AtomicReference holder = new AtomicReference<>(); + Thread t = new Thread(() -> { + try { + var res2 = mgr.acquire((16 * 1024) + 1, 1024, CuVSMatrix.DataType.FLOAT); + holder.set(res2); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + }); + t.start(); + Thread.sleep(1_000); + assertNull(holder.get()); + mgr.release(res1); + t.join(); + assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1"))); + mgr.shutdown(); + } + + public void testNotBlockingOnSufficientMemory() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT); + + AtomicReference holder = new AtomicReference<>(); + Thread t = new Thread(() -> { + try { + var res2 = mgr.acquire((16 * 1024) - 1, 1024, CuVSMatrix.DataType.FLOAT); + holder.set(res2); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + }); + t.start(); + t.join(5_000); + assertNotNull(holder.get()); + assertThat(holder.get().toString(), not(equalTo(res1.toString()))); + mgr.shutdown(); + } + public void testManagedResIsNotClosable() throws Exception { var mgr = new MockPoolingCuVSResourceManager(1); - var res = mgr.acquire(0, 0); - assertThrows(UnsupportedOperationException.class, () -> res.close()); + var res = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + assertThrows(UnsupportedOperationException.class, res::close); mgr.release(res); mgr.shutdown(); } public void testDoubleRelease() throws InterruptedException { var mgr = new MockPoolingCuVSResourceManager(2); - var res1 = mgr.acquire(0, 0); - var res2 = mgr.acquire(0, 0); + var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); + var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT); mgr.release(res1); mgr.release(res2); assertThrows(AssertionError.class, () -> mgr.release(randomFrom(res1, res2))); @@ -80,16 +137,45 @@ public void testDoubleRelease() throws InterruptedException { static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager { - final AtomicInteger idGenerator = new AtomicInteger(); + private final AtomicInteger idGenerator = new AtomicInteger(); + private final List allocations; MockPoolingCuVSResourceManager(int capacity) { - super(capacity); + this(capacity, new ArrayList<>()); + } + + private MockPoolingCuVSResourceManager(int capacity, List allocationList) { + super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList))); + this.allocations = allocationList; + } + + private static long freeMemoryFunction(List allocations) { + return TOTAL_DEVICE_MEMORY_IN_BYTES - allocations.stream().mapToLong(x -> x).sum(); } @Override protected CuVSResources createNew() { return new MockCuVSResources(idGenerator.getAndIncrement()); } + + @Override + public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException { + var res = super.acquire(numVectors, dims, dataType); + long memory = (long) (numVectors * dims * Float.BYTES + * CuVSResourceManager.PoolingCuVSResourceManager.GPU_COMPUTATION_MEMORY_FACTOR); + allocations.add(memory); + log.info("Added [{}]", memory); + return res; + } + + @Override + public void release(ManagedCuVSResources resources) { + if (allocations.isEmpty() == false) { + var x = allocations.removeLast(); + log.info("Removed [{}]", x); + } + super.release(resources); + } } static class MockCuVSResources implements CuVSResources { @@ -105,6 +191,11 @@ public ScopedAccess access() { throw new UnsupportedOperationException(); } + @Override + public int deviceId() { + return 0; + } + @Override public void close() {} @@ -118,4 +209,27 @@ public String toString() { return "MockCuVSResources[id=" + id + "]"; } } + + private static class MockGPUInfoProvider implements GPUInfoProvider { + private final LongSupplier freeMemorySupplier; + + MockGPUInfoProvider(LongSupplier freeMemorySupplier) { + this.freeMemorySupplier = freeMemorySupplier; + } + + @Override + public List availableGPUs() { + throw new UnsupportedOperationException(); + } + + @Override + public List compatibleGPUs() { + throw new UnsupportedOperationException(); + } + + @Override + public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) { + return new CuVSResourcesInfo(freeMemorySupplier.getAsLong(), TOTAL_DEVICE_MEMORY_IN_BYTES); + } + } }