Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@

package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.GPUInfoProvider;
import com.nvidia.cuvs.spi.CuVSProvider;

import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.gpu.GPUSupport;

import java.nio.file.Path;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
* A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
Expand Down Expand Up @@ -44,7 +50,7 @@ public interface CuVSResourceManager {
// numVectors and dims are currently unused, but could be used along with GPU metadata,
// memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
// to decide whether to give out a resource or not.
ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException;
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;

/** Marks the resources as finished with regard to compute. */
void finishedComputation(ManagedCuVSResources resources);
Expand All @@ -65,35 +71,127 @@ static CuVSResourceManager pooling() {
*/
class PoolingCuVSResourceManager implements CuVSResourceManager {

static final Logger logger = LogManager.getLogger(CuVSResourceManager.class);

/** A multiplier on input data to account for intermediate and output data size required while processing it */
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
static final int MAX_RESOURCES = 2;
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES);
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
MAX_RESOURCES,
CuVSProvider.provider().gpuInfoProvider()
);

private final ManagedCuVSResources[] pool;
private final int capacity;
private final GPUInfoProvider gpuInfoProvider;
private int createdCount;

final BlockingQueue<ManagedCuVSResources> pool;
final int capacity;
int createdCount;
ReentrantLock lock = new ReentrantLock();
Condition enoughResourcesCondition = lock.newCondition();

public PoolingCuVSResourceManager(int capacity) {
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
if (capacity < 1 || capacity > MAX_RESOURCES) {
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
}
this.capacity = capacity;
this.pool = new ArrayBlockingQueue<>(capacity);
this.gpuInfoProvider = gpuInfoProvider;
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
}

@Override
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
ManagedCuVSResources res = pool.poll();
if (res != null) {
/**
 * Picks a resource from the array-backed pool without blocking.
 * <p>
 * Scans the first {@code createdCount} slots and returns the first entry whose {@code locked}
 * flag is false. If none is found and {@code createdCount < capacity}, wraps the result of
 * {@code createNew()} in a new {@code ManagedCuVSResources}, stores it in the next free slot and
 * returns it.
 * <p>
 * NOTE(review): the {@code synchronized (this)} creation path further down looks like it belongs
 * to the earlier queue-based version shown in this diff; the trailing {@code return null} is
 * presumably the "nothing available" result of the array-based version — confirm against the
 * merged file.
 */
private ManagedCuVSResources getResourceFromPool() {
// First preference: reuse an already created entry that is not marked as locked.
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked == false) {
return res;
}
}
// Otherwise create a new entry while the pool is below its capacity.
if (createdCount < capacity) {
var res = new ManagedCuVSResources(Objects.requireNonNull(createNew()));
pool[createdCount++] = res;
return res;
}
synchronized (this) {
if (createdCount < capacity) {
createdCount++;
return new ManagedCuVSResources(Objects.requireNonNull(createNew()));
return null;
}

/**
 * Counts the entries among the first {@code createdCount} pool slots whose {@code locked} flag
 * is set.
 * <p>
 * NOTE(review): the {@code pool.take()} statement and its comment near the end look like
 * leftovers of the earlier queue-based {@code acquire} shown in this diff — confirm against the
 * merged file.
 */
private int numLockedResources() {
int lockedResources = 0;
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked) {
lockedResources++;
}
}
// Otherwise, wait for one to be released
return pool.take();
return lockedResources;
}

/**
 * Hands out a pooled resource, blocking until one is free and the GPU reports enough free memory
 * for the requested workload.
 * <p>
 * The required memory is estimated from the vector count, dimension count and element type. When
 * no pooled resource is currently locked, the free-memory check is skipped and the resource is
 * handed out straight away, so that a single large request cannot wait forever. Otherwise the
 * caller waits on {@code enoughResourcesCondition} and re-evaluates after every signal. The
 * returned resource is marked as locked.
 *
 * @param numVectors number of vectors that will be processed
 * @param dims number of dimensions per vector
 * @param dataType element type of the vectors, used to size one element
 * @return the acquired resource, with its {@code locked} flag set
 * @throws IllegalArgumentException if the estimated memory exceeds the total memory of the device
 * @throws InterruptedException if the thread is interrupted while waiting for resources
 */
@Override
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
    // Lock before entering the try block: if lock() itself fails, the finally clause must not
    // attempt to unlock a lock this thread never obtained.
    lock.lock();
    try {
        boolean allConditionsMet = false;
        ManagedCuVSResources res = null;
        while (allConditionsMet == false) {
            res = getResourceFromPool();

            final boolean enoughMemory;
            if (res != null) {
                long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
                logger.info(
                    "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
                    numVectors,
                    dims,
                    dataType.name(),
                    requiredMemoryInBytes
                );

                // Check immutable constraints: a request larger than the whole device can never succeed
                long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
                if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
                    String message = Strings.format(
                        "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
                        numVectors,
                        dims,
                        totalDeviceMemoryInBytes
                    );
                    logger.error(message);
                    throw new IllegalArgumentException(message);
                }

                // If no resource in the pool is locked, short circuit to avoid livelock
                if (numLockedResources() == 0) {
                    logger.info("No resources currently locked, proceeding");
                    break;
                }

                // Check resources availability
                long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
                enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
                // Both placeholders need an argument; the second one reports the outcome of the check
                logger.info("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory);
            } else {
                logger.info("No resources available in pool");
                enoughMemory = false;
            }
            // TODO: add enoughComputation / enoughComputationCondition here
            allConditionsMet = enoughMemory; // && enoughComputation
            if (allConditionsMet == false) {
                // Releases the lock while waiting; release() signals this condition
                enoughResourcesCondition.await();
            }
        }
        res.locked = true;
        return res;
    } finally {
        lock.unlock();
    }
}

/**
 * Estimates how many bytes of GPU memory are needed to process the given input.
 * <p>
 * The raw input size (vectors x dimensions x bytes per element) is scaled by
 * {@code GPU_COMPUTATION_MEMORY_FACTOR}. The multiplication is carried out in floating point and
 * truncated to a {@code long} at the end.
 *
 * @param numVectors number of vectors in the input
 * @param dims number of dimensions per vector
 * @param dataType element type, which determines the size of a single element
 * @return the estimated memory requirement in bytes
 */
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
    final int bytesPerElement = switch (dataType) {
        case BYTE -> Byte.BYTES;
        case FLOAT -> Float.BYTES;
        case INT, UINT -> Integer.BYTES;
    };
    // Same left-to-right evaluation order as a single chained product
    final double scaledVectorCount = GPU_COMPUTATION_MEMORY_FACTOR * numVectors;
    final double scaledElementCount = scaledVectorCount * dims;
    final double estimatedBytes = scaledElementCount * bytesPerElement;
    return (long) estimatedBytes;
}

// visible for testing
Expand All @@ -103,28 +201,39 @@ protected CuVSResources createNew() {

/**
 * Notes that the compute phase using the given resources is over. At present this only writes
 * a log entry; the state of the pool is left untouched.
 *
 * @param resources the resources whose computation has finished
 */
@Override
public void finishedComputation(ManagedCuVSResources resources) {
    // Nothing else happens here yet. This hook could later let acquire hand out resources that
    // are still held but no longer computing, e.g. by calling enoughResourcesCondition.signalAll()
    logger.info("Computation finished");
}

/**
 * Returns the given resources to the pool.
 * <p>
 * Under {@code lock}, asserts that the resource is currently marked as locked, clears the flag
 * and signals every thread waiting on {@code enoughResourcesCondition}.
 * <p>
 * NOTE(review): the {@code pool.offer(...)} statement and the {@code added} assertion look like
 * they belong to the earlier queue-based version shown in this diff — confirm against the merged
 * file.
 */
@Override
public void release(ManagedCuVSResources resources) {
var added = pool.offer(Objects.requireNonNull(resources));
assert added : "Failed to release resource back to pool";
logger.info("Releasing resources to pool");
try {
lock.lock();
assert resources.locked;
resources.locked = false;
// Wake all waiters so that each can re-evaluate its own conditions
enoughResourcesCondition.signalAll();
} finally {
lock.unlock();
}
}

/**
 * Closes the underlying delegate of every entry created so far, i.e. the first
 * {@code createdCount} slots of the pool, asserting that each slot is populated.
 * <p>
 * NOTE(review): the for-each loop header and the {@code pool.clear()} call look like they belong
 * to the earlier queue-based version shown in this diff — confirm against the merged file.
 */
@Override
public void shutdown() {
for (ManagedCuVSResources res : pool) {
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
assert res != null;
// Close the wrapped resource directly through the delegate field
res.delegate.close();
}
pool.clear();
}
}

/** A managed resource. Cannot be closed. */
final class ManagedCuVSResources implements CuVSResources {

final CuVSResources delegate;
boolean locked = false;

ManagedCuVSResources(CuVSResources resources) {
this.delegate = resources;
Expand All @@ -135,6 +244,11 @@ public ScopedAccess access() {
return delegate.access();
}

/**
 * Reports the device id of the wrapped resources.
 *
 * @return the value returned by the delegate's {@code deviceId()}
 */
@Override
public int deviceId() {
    final int delegateDeviceId = delegate.deviceId();
    return delegateDeviceId;
}

@Override
public void close() {
throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrV
mockGraph = writeGraph(vectors, graphLevelNodeOffsets);
} else {
var dataset = datasetOrVectors.dataset;
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns());
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType());
try {
try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
assert index != null : "GPU index should be built for field: " + fieldInfo.name;
Expand Down
Loading