Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.GPUInfoProvider;
import com.nvidia.cuvs.spi.CuVSProvider;

import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.gpu.GPUSupport;

import java.nio.file.Path;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
* A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
Expand Down Expand Up @@ -65,35 +68,104 @@ static CuVSResourceManager pooling() {
*/
class PoolingCuVSResourceManager implements CuVSResourceManager {

/** A multiplier on input data to account for intermediate and output data size required while processing it */
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
static final int MAX_RESOURCES = 2;
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES);
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
MAX_RESOURCES,
CuVSProvider.provider().gpuInfoProvider()
);

final BlockingQueue<ManagedCuVSResources> pool;
final int capacity;
int createdCount;
private final ManagedCuVSResources[] pool;
private final int capacity;
private final GPUInfoProvider gpuInfoProvider;
private int createdCount;

public PoolingCuVSResourceManager(int capacity) {
ReentrantLock lock = new ReentrantLock();
Condition enoughResourcesCondition = lock.newCondition();

public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
if (capacity < 1 || capacity > MAX_RESOURCES) {
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
}
this.capacity = capacity;
this.pool = new ArrayBlockingQueue<>(capacity);
this.gpuInfoProvider = gpuInfoProvider;
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
}

@Override
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
ManagedCuVSResources res = pool.poll();
if (res != null) {
private ManagedCuVSResources getResourceFromPool() {
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked == false) {
return res;
}
}
if (createdCount < capacity) {
var res = new ManagedCuVSResources(Objects.requireNonNull(createNew()));
pool[createdCount++] = res;
return res;
}
synchronized (this) {
if (createdCount < capacity) {
createdCount++;
return new ManagedCuVSResources(Objects.requireNonNull(createNew()));
return null;
}

private int numLockedResources() {
int lockedResources = 0;
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked) {
lockedResources++;
}
}
// Otherwise, wait for one to be released
return pool.take();
return lockedResources;
}

/**
 * Acquires a resource from the pool, waiting until one is available and the GPU reports enough free memory
 * for the requested computation.
 * <p>
 * The memory requirement is computed by {@code estimateRequiredMemory(numVectors, dims)}. When a resource is
 * available and no pooled resource is currently locked, the memory checks are skipped and the resource is
 * handed out immediately. Otherwise the caller waits on {@code enoughResourcesCondition} and re-evaluates
 * every time it is signalled.
 *
 * @param numVectors number of vectors the caller is going to process
 * @param dims       number of dimensions of each vector
 * @return the acquired resource, with its {@code locked} flag set
 * @throws IllegalArgumentException if another resource is locked and the estimated memory exceeds the total
 *                                  device memory reported for the candidate resource
 * @throws InterruptedException     if the thread is interrupted while waiting
 */
@Override
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
    // Lock BEFORE entering the try block: if lock() itself fails, the finally clause must not call unlock()
    // on a lock this thread does not hold (that would raise IllegalMonitorStateException and mask the cause).
    lock.lock();
    try {
        // Depends only on the arguments, so compute it once instead of on every wake-up.
        final long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);

        boolean allConditionsMet = false;
        ManagedCuVSResources res = null;
        while (allConditionsMet == false) {
            res = getResourceFromPool();

            final boolean enoughMemory;
            if (res != null) {
                // If no resource in the pool is locked, short circuit to avoid livelock
                if (numLockedResources() == 0) {
                    break;
                }
                // Query the GPU once so that the total check, the error message and the free check
                // all look at the same snapshot.
                final var gpuInfo = gpuInfoProvider.getCurrentInfo(res);
                final long totalMemoryInBytes = gpuInfo.totalDeviceMemoryInBytes();
                if (requiredMemoryInBytes > totalMemoryInBytes) {
                    throw new IllegalArgumentException(
                        Strings.format(
                            "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
                            numVectors,
                            dims,
                            totalMemoryInBytes / (1024L * 1024L)
                        )
                    );
                }
                enoughMemory = requiredMemoryInBytes <= gpuInfo.freeDeviceMemoryInBytes();
            } else {
                enoughMemory = false;
            }
            // TODO: add enoughComputation / enoughComputationCondition here
            allConditionsMet = enoughMemory; // && enoughComputation
            if (allConditionsMet == false) {
                // Releases the lock while waiting; woken by signalAll() on this condition.
                enoughResourcesCondition.await();
            }
        }
        // The loop only exits (via break or allConditionsMet) with a non-null resource.
        res.locked = true;
        return res;
    } finally {
        lock.unlock();
    }
}

/**
 * Estimates the GPU memory, in bytes, needed to process {@code numVectors} float vectors of {@code dims}
 * dimensions: the raw input size scaled by {@code GPU_COMPUTATION_MEMORY_FACTOR}.
 * The product is accumulated as a {@code double} and truncated to {@code long} at the end.
 */
private long estimateRequiredMemory(int numVectors, int dims) {
    double bytes = GPU_COMPUTATION_MEMORY_FACTOR * numVectors;
    bytes *= dims;
    bytes *= Float.BYTES;
    return (long) bytes;
}

// visible for testing
Expand All @@ -104,27 +176,36 @@ protected CuVSResources createNew() {
/**
 * Currently a no-op: the body contains no statements and {@code resources} is not touched.
 * NOTE(review): presumably called once the GPU work on {@code resources} is done but before the resource is
 * released; confirm against the interface contract.
 */
@Override
public void finishedComputation(ManagedCuVSResources resources) {
// currently does nothing, but could allow acquire to return possibly blocked resources
// Possible future behaviour (left disabled): wake the threads waiting in acquire so they re-check availability.
// enoughResourcesCondition.signalAll()
}

@Override
public void release(ManagedCuVSResources resources) {
var added = pool.offer(Objects.requireNonNull(resources));
assert added : "Failed to release resource back to pool";
// Take the lock before the try block so the finally clause only ever unlocks a lock this thread holds.
lock.lock();
try {
    assert resources.locked;
    resources.locked = false;
    // Wake every thread waiting in acquire(); each one re-checks availability for its own request.
    enoughResourcesCondition.signalAll();
} finally {
    lock.unlock();
}
}

@Override
public void shutdown() {
for (ManagedCuVSResources res : pool) {
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
assert res != null;
res.delegate.close();
}
pool.clear();
}
}

/** A managed resource. Cannot be closed. */
final class ManagedCuVSResources implements CuVSResources {

final CuVSResources delegate;
boolean locked = false;

ManagedCuVSResources(CuVSResources resources) {
this.delegate = resources;
Expand All @@ -135,6 +216,11 @@ public ScopedAccess access() {
return delegate.access();
}

@Override
public int deviceId() {
return delegate.deviceId();
}

@Override
public void close() {
throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,32 @@
package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.CuVSResourcesInfo;
import com.nvidia.cuvs.GPUInfo;
import com.nvidia.cuvs.GPUInfoProvider;

import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongSupplier;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;

public class CuVSResourceManagerTests extends ESTestCase {

private static final Logger log = LogManager.getLogger(CuVSResourceManagerTests.class);

public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024;

public void testBasic() throws InterruptedException {
var mgr = new MockPoolingCuVSResourceManager(2);
var res1 = mgr.acquire(0, 0);
Expand Down Expand Up @@ -60,10 +74,52 @@ public void testBlocking() throws Exception {
mgr.shutdown();
}

/**
 * A second acquire must block while the mock GPU does not report enough free memory, and must complete
 * once the first resource is released.
 * <p>
 * The first request (16 * 1024 vectors of 1024 float dims, times the computation memory factor) is recorded
 * by the mock as an allocation against {@code TOTAL_DEVICE_MEMORY_IN_BYTES}; the second request asks for one
 * vector more than the first.
 */
public void testBlockingOnInsufficientMemory() throws Exception {
    var mgr = new MockPoolingCuVSResourceManager(2);
    var res1 = mgr.acquire(16 * 1024, 1024);

    AtomicReference<CuVSResources> holder = new AtomicReference<>();
    Thread t = new Thread(() -> {
        try {
            var res2 = mgr.acquire((16 * 1024) + 1, 1024);
            holder.set(res2);
        } catch (InterruptedException e) {
            throw new AssertionError(e);
        }
    });
    t.start();
    // Give the worker time to reach the blocking wait; it must not have acquired anything yet.
    Thread.sleep(1_000);
    assertNull(holder.get());
    mgr.release(res1);
    // Bounded join: if the waiter is never woken, fail the assertion below instead of hanging the suite.
    t.join(5_000);
    // Fail with a clear assertion rather than a NullPointerException when nothing was acquired.
    assertNotNull(holder.get());
    assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1")));
    mgr.shutdown();
}

/**
 * With one resource already acquired, a second request that is one vector smaller than the first must be
 * served within the join timeout and must yield a different resource than the first one.
 */
public void testNotBlockingOnSufficientMemory() throws Exception {
    final var manager = new MockPoolingCuVSResourceManager(2);
    final var first = manager.acquire(16 * 1024, 1024);

    final AtomicReference<CuVSResources> second = new AtomicReference<>();
    final Runnable acquireSecond = () -> {
        try {
            second.set(manager.acquire((16 * 1024) - 1, 1024));
        } catch (InterruptedException e) {
            throw new AssertionError(e);
        }
    };

    final Thread worker = new Thread(acquireSecond);
    worker.start();
    // Wait at most five seconds for the worker to finish.
    worker.join(5_000);

    assertNotNull(second.get());
    assertThat(second.get().toString(), not(equalTo(first.toString())));
    manager.shutdown();
}

public void testManagedResIsNotClosable() throws Exception {
var mgr = new MockPoolingCuVSResourceManager(1);
var res = mgr.acquire(0, 0);
assertThrows(UnsupportedOperationException.class, () -> res.close());
assertThrows(UnsupportedOperationException.class, res::close);
mgr.release(res);
mgr.shutdown();
}
Expand All @@ -80,16 +136,45 @@ public void testDoubleRelease() throws InterruptedException {

static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager {

final AtomicInteger idGenerator = new AtomicInteger();
private final AtomicInteger idGenerator = new AtomicInteger();
private final List<Long> allocations;

MockPoolingCuVSResourceManager(int capacity) {
super(capacity);
this(capacity, new ArrayList<>());
}

/**
 * Creates a manager whose mock GPU info provider derives the free memory from {@code allocationList},
 * and stores that same list in {@code allocations}.
 *
 * @param capacity       pool capacity passed to the superclass constructor
 * @param allocationList list of allocation sizes in bytes, shared between the provider and this manager
 */
private MockPoolingCuVSResourceManager(int capacity, List<Long> allocationList) {
// The supplier re-reads the list on every invocation, so later additions and removals change the
// free memory the provider reports.
super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList)));
this.allocations = allocationList;
}

/**
 * Computes the free device memory as {@code TOTAL_DEVICE_MEMORY_IN_BYTES} minus the sum of all entries
 * in {@code allocations}.
 */
private static long freeMemoryFunction(List<Long> allocations) {
    long allocatedBytes = 0L;
    for (long allocation : allocations) {
        allocatedBytes += allocation;
    }
    return TOTAL_DEVICE_MEMORY_IN_BYTES - allocatedBytes;
}

/** Creates a mock resource whose id is the next value of {@code idGenerator}. */
@Override
protected CuVSResources createNew() {
    final int nextId = idGenerator.getAndIncrement();
    return new MockCuVSResources(nextId);
}

/**
 * Delegates to the real acquire and, once it returns, records the memory the request is considered to use
 * so that the mock GPU info provider reports correspondingly less free memory.
 */
@Override
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
    var res = super.acquire(numVectors, dims);
    // Lead with the double factor so the whole product is evaluated in floating point. Multiplying the
    // int operands first would overflow for large inputs before the factor is applied.
    long memory = (long) (CuVSResourceManager.PoolingCuVSResourceManager.GPU_COMPUTATION_MEMORY_FACTOR
        * numVectors
        * dims
        * Float.BYTES);
    allocations.add(memory);
    log.info("Added [{}]", memory);
    return res;
}

/**
 * Drops the most recently recorded allocation, if there is one, and then delegates to the real release.
 */
@Override
public void release(ManagedCuVSResources resources) {
    final boolean hasRecordedAllocation = allocations.isEmpty() == false;
    if (hasRecordedAllocation) {
        final var freed = allocations.removeLast();
        log.info("Removed [{}]", freed);
    }
    super.release(resources);
}
}

static class MockCuVSResources implements CuVSResources {
Expand All @@ -105,6 +190,11 @@ public ScopedAccess access() {
throw new UnsupportedOperationException();
}

/** Always reports device id 0. */
@Override
public int deviceId() {
return 0;
}

/** Intentionally empty: closing this mock does nothing. */
@Override
public void close() {}

Expand All @@ -118,4 +208,27 @@ public String toString() {
return "MockCuVSResources[id=" + id + "]";
}
}

/**
 * Test double for {@link GPUInfoProvider}. It reports {@code TOTAL_DEVICE_MEMORY_IN_BYTES} as the second
 * constructor argument of {@link CuVSResourcesInfo} and, as the first, a value read from the supplied
 * {@link LongSupplier} on every query. Enumerating GPUs is not supported.
 */
private static class MockGPUInfoProvider implements GPUInfoProvider {
    private final LongSupplier freeBytes;

    MockGPUInfoProvider(LongSupplier freeMemorySupplier) {
        freeBytes = freeMemorySupplier;
    }

    /** Builds a fresh info object, asking the supplier for the current free memory each time. */
    @Override
    public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) {
        final long free = freeBytes.getAsLong();
        return new CuVSResourcesInfo(free, TOTAL_DEVICE_MEMORY_IN_BYTES);
    }

    /** Not supported by this mock. */
    @Override
    public List<GPUInfo> availableGPUs() throws Throwable {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this mock. */
    @Override
    public List<GPUInfo> compatibleGPUs() throws Throwable {
        throw new UnsupportedOperationException();
    }
}
}