Skip to content

Commit 83aa729

Browse files
authored
PoolingCuVSResourceManager with memory availability (#133242)
1 parent ebb36e0 commit 83aa729

File tree

3 files changed

+265
-37
lines changed

3 files changed

+265
-37
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 137 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,20 @@
77

88
package org.elasticsearch.xpack.gpu.codec;
99

10+
import com.nvidia.cuvs.CuVSMatrix;
1011
import com.nvidia.cuvs.CuVSResources;
12+
import com.nvidia.cuvs.GPUInfoProvider;
13+
import com.nvidia.cuvs.spi.CuVSProvider;
1114

15+
import org.elasticsearch.core.Strings;
16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.logging.Logger;
1218
import org.elasticsearch.xpack.gpu.GPUSupport;
1319

1420
import java.nio.file.Path;
1521
import java.util.Objects;
16-
import java.util.concurrent.ArrayBlockingQueue;
17-
import java.util.concurrent.BlockingQueue;
22+
import java.util.concurrent.locks.Condition;
23+
import java.util.concurrent.locks.ReentrantLock;
1824

1925
/**
2026
* A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
@@ -44,7 +50,7 @@ public interface CuVSResourceManager {
4450
// numVectors and dims are currently unused, but could be used along with GPU metadata,
4551
// memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
4652
// to give out a resources or not.
47-
ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException;
53+
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;
4854

4955
/** Marks the resources as finished with regard to compute. */
5056
void finishedComputation(ManagedCuVSResources resources);
@@ -65,35 +71,127 @@ static CuVSResourceManager pooling() {
6571
*/
6672
class PoolingCuVSResourceManager implements CuVSResourceManager {
6773

74+
static final Logger logger = LogManager.getLogger(CuVSResourceManager.class);
75+
76+
/** A multiplier on input data to account for intermediate and output data size required while processing it */
77+
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
6878
static final int MAX_RESOURCES = 2;
69-
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES);
79+
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
80+
MAX_RESOURCES,
81+
CuVSProvider.provider().gpuInfoProvider()
82+
);
83+
84+
private final ManagedCuVSResources[] pool;
85+
private final int capacity;
86+
private final GPUInfoProvider gpuInfoProvider;
87+
private int createdCount;
7088

71-
final BlockingQueue<ManagedCuVSResources> pool;
72-
final int capacity;
73-
int createdCount;
89+
ReentrantLock lock = new ReentrantLock();
90+
Condition enoughResourcesCondition = lock.newCondition();
7491

75-
public PoolingCuVSResourceManager(int capacity) {
92+
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
7693
if (capacity < 1 || capacity > MAX_RESOURCES) {
7794
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
7895
}
7996
this.capacity = capacity;
80-
this.pool = new ArrayBlockingQueue<>(capacity);
97+
this.gpuInfoProvider = gpuInfoProvider;
98+
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
8199
}
82100

83-
@Override
84-
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
85-
ManagedCuVSResources res = pool.poll();
86-
if (res != null) {
101+
private ManagedCuVSResources getResourceFromPool() {
102+
for (int i = 0; i < createdCount; ++i) {
103+
var res = pool[i];
104+
if (res.locked == false) {
105+
return res;
106+
}
107+
}
108+
if (createdCount < capacity) {
109+
var res = new ManagedCuVSResources(Objects.requireNonNull(createNew()));
110+
pool[createdCount++] = res;
87111
return res;
88112
}
89-
synchronized (this) {
90-
if (createdCount < capacity) {
91-
createdCount++;
92-
return new ManagedCuVSResources(Objects.requireNonNull(createNew()));
113+
return null;
114+
}
115+
116+
private int numLockedResources() {
117+
int lockedResources = 0;
118+
for (int i = 0; i < createdCount; ++i) {
119+
var res = pool[i];
120+
if (res.locked) {
121+
lockedResources++;
93122
}
94123
}
95-
// Otherwise, wait for one to be released
96-
return pool.take();
124+
return lockedResources;
125+
}
126+
127+
@Override
128+
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
129+
try {
130+
lock.lock();
131+
132+
boolean allConditionsMet = false;
133+
ManagedCuVSResources res = null;
134+
while (allConditionsMet == false) {
135+
res = getResourceFromPool();
136+
137+
final boolean enoughMemory;
138+
if (res != null) {
139+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
140+
logger.info(
141+
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
142+
numVectors,
143+
dims,
144+
dataType.name(),
145+
requiredMemoryInBytes
146+
);
147+
148+
// Check immutable constraints
149+
long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
150+
if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
151+
String message = Strings.format(
152+
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
153+
numVectors,
154+
dims,
155+
totalDeviceMemoryInBytes
156+
);
157+
logger.error(message);
158+
throw new IllegalArgumentException(message);
159+
}
160+
161+
// If no resource in the pool is locked, short circuit to avoid livelock
162+
if (numLockedResources() == 0) {
163+
logger.info("No resources currently locked, proceeding");
164+
break;
165+
}
166+
167+
// Check resources availability
168+
long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
169+
enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
170+
logger.info("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes);
171+
} else {
172+
logger.info("No resources available in pool");
173+
enoughMemory = false;
174+
}
175+
// TODO: add enoughComputation / enoughComputationCondition here
176+
allConditionsMet = enoughMemory; // && enoughComputation
177+
if (allConditionsMet == false) {
178+
enoughResourcesCondition.await();
179+
}
180+
}
181+
res.locked = true;
182+
return res;
183+
} finally {
184+
lock.unlock();
185+
}
186+
}
187+
188+
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
189+
int elementTypeBytes = switch (dataType) {
190+
case FLOAT -> Float.BYTES;
191+
case INT, UINT -> Integer.BYTES;
192+
case BYTE -> Byte.BYTES;
193+
};
194+
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes);
97195
}
98196

99197
// visible for testing
@@ -103,28 +201,39 @@ protected CuVSResources createNew() {
103201

104202
@Override
105203
public void finishedComputation(ManagedCuVSResources resources) {
204+
logger.info("Computation finished");
106205
// currently does nothing, but could allow acquire to return possibly blocked resources
206+
// enoughResourcesCondition.signalAll()
107207
}
108208

109209
@Override
110210
public void release(ManagedCuVSResources resources) {
111-
var added = pool.offer(Objects.requireNonNull(resources));
112-
assert added : "Failed to release resource back to pool";
211+
logger.info("Releasing resources to pool");
212+
try {
213+
lock.lock();
214+
assert resources.locked;
215+
resources.locked = false;
216+
enoughResourcesCondition.signalAll();
217+
} finally {
218+
lock.unlock();
219+
}
113220
}
114221

115222
@Override
116223
public void shutdown() {
117-
for (ManagedCuVSResources res : pool) {
224+
for (int i = 0; i < createdCount; ++i) {
225+
var res = pool[i];
226+
assert res != null;
118227
res.delegate.close();
119228
}
120-
pool.clear();
121229
}
122230
}
123231

124232
/** A managed resource. Cannot be closed. */
125233
final class ManagedCuVSResources implements CuVSResources {
126234

127235
final CuVSResources delegate;
236+
boolean locked = false;
128237

129238
ManagedCuVSResources(CuVSResources resources) {
130239
this.delegate = resources;
@@ -135,6 +244,11 @@ public ScopedAccess access() {
135244
return delegate.access();
136245
}
137246

247+
@Override
248+
public int deviceId() {
249+
return delegate.deviceId();
250+
}
251+
138252
@Override
139253
public void close() {
140254
throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrV
250250
mockGraph = writeGraph(size, graphLevelNodeOffsets);
251251
} else {
252252
var dataset = datasetOrVectors.getDataset();
253-
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns());
253+
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns(), dataset.dataType());
254254
try {
255255
try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
256256
assert index != null : "GPU index should be built for field: " + fieldInfo.name;

0 commit comments

Comments
 (0)