Skip to content

Commit 729ef41

Browse files
authored
[GPU] Extend CuVSResourcesManager (elastic#137588) (elastic#137621)
CuVSResourceManager controls access to GPU resources in order to ensure an appropriate level of parallelism (allowing more than one GPU thread while keeping a reasonable upper bound) and to limit the amount of GPU memory in use, preventing CUDA out-of-memory errors. This PR extends the memory-control part by introducing different strategies for memory accounting ("real", based on API calls to the device, and "tracking", which remembers the amount of memory requested during acquisition) and different memory estimations depending on the CAGRA graph build algorithm. The former will allow us to use pooled memory (where the amount of available memory differs from the free device memory); the latter will allow us to use the IVF-PQ CAGRA graph build algorithm for larger datasets.
1 parent 28b1e7a commit 729ef41

File tree

6 files changed

+330
-138
lines changed

6 files changed

+330
-138
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
package org.elasticsearch.xpack.gpu.codec;
99

10+
import com.nvidia.cuvs.CagraIndexParams;
1011
import com.nvidia.cuvs.CuVSMatrix;
1112
import com.nvidia.cuvs.CuVSResources;
12-
import com.nvidia.cuvs.GPUInfoProvider;
1313
import com.nvidia.cuvs.spi.CuVSProvider;
1414

1515
import org.elasticsearch.core.Strings;
@@ -47,10 +47,8 @@ public interface CuVSResourceManager {
4747
* effect on GPU memory and compute usage to determine whether to give out
4848
* another resource or wait for a resource to be returned before giving out another.
4949
*/
50-
// numVectors and dims are currently unused, but could be used along with GPU metadata,
51-
// memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
52-
// to give out a resources or not.
53-
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;
50+
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
51+
throws InterruptedException;
5452

5553
/** Marks the resources as finished with regard to compute. */
5654
void finishedComputation(ManagedCuVSResources resources);
@@ -80,31 +78,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8078
static class Holder {
8179
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
8280
MAX_RESOURCES,
83-
CuVSProvider.provider().gpuInfoProvider()
81+
new RealGPUMemoryService(CuVSProvider.provider().gpuInfoProvider())
8482
);
8583
}
8684

8785
private final ManagedCuVSResources[] pool;
8886
private final int capacity;
89-
private final GPUInfoProvider gpuInfoProvider;
87+
private final GPUMemoryService gpuMemoryService;
9088
private int createdCount;
9189

9290
ReentrantLock lock = new ReentrantLock();
9391
Condition enoughResourcesCondition = lock.newCondition();
9492

95-
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
93+
PoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) {
9694
if (capacity < 1 || capacity > MAX_RESOURCES) {
9795
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
9896
}
9997
this.capacity = capacity;
100-
this.gpuInfoProvider = gpuInfoProvider;
98+
this.gpuMemoryService = gpuMemoryService;
10199
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
102100
}
103101

104102
private ManagedCuVSResources getResourceFromPool() {
105103
for (int i = 0; i < createdCount; ++i) {
106104
var res = pool[i];
107-
if (res.locked == false) {
105+
if (res.isLocked() == false) {
108106
return res;
109107
}
110108
}
@@ -120,43 +118,45 @@ private int numLockedResources() {
120118
int lockedResources = 0;
121119
for (int i = 0; i < createdCount; ++i) {
122120
var res = pool[i];
123-
if (res.locked) {
121+
if (res.isLocked()) {
124122
lockedResources++;
125123
}
126124
}
127125
return lockedResources;
128126
}
129127

130128
@Override
131-
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
129+
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
130+
throws InterruptedException {
132131
try {
133132
var started = System.nanoTime();
134133
lock.lock();
135134

136135
boolean allConditionsMet = false;
137136
ManagedCuVSResources res = null;
137+
138+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, cagraIndexParams);
139+
logger.debug(
140+
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
141+
numVectors,
142+
dims,
143+
dataType.name(),
144+
requiredMemoryInBytes
145+
);
146+
138147
while (allConditionsMet == false) {
139148
res = getResourceFromPool();
140149

141150
final boolean enoughMemory;
142151
if (res != null) {
143-
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
144-
logger.debug(
145-
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
146-
numVectors,
147-
dims,
148-
dataType.name(),
149-
requiredMemoryInBytes
150-
);
151-
152152
// Check immutable constraints
153-
long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
154-
if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
153+
long totalMemoryInBytes = gpuMemoryService.totalMemoryInBytes(res);
154+
if (requiredMemoryInBytes > totalMemoryInBytes) {
155155
String message = Strings.format(
156156
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
157157
numVectors,
158158
dims,
159-
totalDeviceMemoryInBytes
159+
totalMemoryInBytes
160160
);
161161
logger.error(message);
162162
throw new IllegalArgumentException(message);
@@ -169,9 +169,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
169169
}
170170

171171
// Check resources availability
172-
long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
173-
enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
174-
logger.debug("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory);
172+
long availableMemoryInBytes = gpuMemoryService.availableMemoryInBytes(res);
173+
enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes;
174+
logger.debug("Free device memory [{} B], enoughMemory[{}]", availableMemoryInBytes, enoughMemory);
175175
} else {
176176
logger.debug("No resources available in pool");
177177
enoughMemory = false;
@@ -184,19 +184,33 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
184184
}
185185
var elapsed = started - System.nanoTime();
186186
logger.debug("Resource acquired in [{}ms]", elapsed / 1_000_000.0);
187-
res.locked = true;
187+
gpuMemoryService.reserveMemory(requiredMemoryInBytes);
188+
res.lock(() -> gpuMemoryService.releaseMemory(requiredMemoryInBytes));
188189
return res;
189190
} finally {
190191
lock.unlock();
191192
}
192193
}
193194

194-
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
195+
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) {
195196
int elementTypeBytes = switch (dataType) {
196197
case FLOAT -> Float.BYTES;
197198
case INT, UINT -> Integer.BYTES;
198199
case BYTE -> Byte.BYTES;
199200
};
201+
202+
if (cagraIndexParams.getCagraGraphBuildAlgo() == CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ
203+
&& cagraIndexParams.getCuVSIvfPqParams() != null
204+
&& cagraIndexParams.getCuVSIvfPqParams().getIndexParams() != null
205+
&& cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim() != 0) {
206+
// See https://docs.rapids.ai/api/cuvs/nightly/neighbors/ivfpq/#index-device-memory
207+
var pqDim = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqDim();
208+
var pqBits = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getPqBits();
209+
var numClusters = cagraIndexParams.getCuVSIvfPqParams().getIndexParams().getnLists();
210+
var approximatedIvfBytes = numVectors * (pqDim * (pqBits / 8.0) + elementTypeBytes) + (long) numClusters * Integer.BYTES;
211+
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * approximatedIvfBytes);
212+
}
213+
200214
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes);
201215
}
202216

@@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) {
217231
logger.debug("Releasing resources to pool");
218232
try {
219233
lock.lock();
220-
assert resources.locked;
221-
resources.locked = false;
234+
assert resources.isLocked();
235+
resources.unlock();
222236
enoughResourcesCondition.signalAll();
223237
} finally {
224238
lock.unlock();
@@ -238,8 +252,9 @@ public void shutdown() {
238252
/** A managed resource. Cannot be closed. */
239253
final class ManagedCuVSResources implements CuVSResources {
240254

241-
final CuVSResources delegate;
242-
boolean locked = false;
255+
private final CuVSResources delegate;
256+
private static final Runnable NOT_LOCKED = () -> {};
257+
private Runnable unlockAction = NOT_LOCKED;
243258

244259
ManagedCuVSResources(CuVSResources resources) {
245260
this.delegate = resources;
@@ -269,5 +284,18 @@ public Path tempDirectory() {
269284
public String toString() {
270285
return "ManagedCuVSResources[delegate=" + delegate + "]";
271286
}
287+
288+
void lock(Runnable unlockAction) {
289+
this.unlockAction = unlockAction;
290+
}
291+
292+
void unlock() {
293+
unlockAction.run();
294+
unlockAction = NOT_LOCKED;
295+
}
296+
297+
boolean isLocked() {
298+
return unlockAction != NOT_LOCKED;
299+
}
272300
}
273301
}

0 commit comments

Comments
 (0)