Skip to content

Commit 5050166

Browse files
committed
Abstract GPU memory tracking
1 parent 6f90938 commit 5050166

File tree

6 files changed

+219
-107
lines changed

6 files changed

+219
-107
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
package org.elasticsearch.xpack.gpu.codec;
99

10+
import com.nvidia.cuvs.CagraIndexParams;
1011
import com.nvidia.cuvs.CuVSMatrix;
1112
import com.nvidia.cuvs.CuVSResources;
12-
import com.nvidia.cuvs.GPUInfoProvider;
1313
import com.nvidia.cuvs.spi.CuVSProvider;
1414

1515
import org.elasticsearch.core.Strings;
@@ -47,10 +47,12 @@ public interface CuVSResourceManager {
4747
* effect on GPU memory and compute usage to determine whether to give out
4848
* another resource or wait for a resource to be returned before giving out another.
4949
*/
50-
// numVectors and dims are currently unused, but could be used along with GPU metadata,
51-
// memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
52-
// to give out a resources or not.
53-
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;
50+
ManagedCuVSResources acquire(
51+
int numVectors,
52+
int dims,
53+
CuVSMatrix.DataType dataType,
54+
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
55+
) throws InterruptedException;
5456

5557
/** Marks the resources as finished with regard to compute. */
5658
void finishedComputation(ManagedCuVSResources resources);
@@ -80,31 +82,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8082
static class Holder {
8183
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
8284
MAX_RESOURCES,
83-
CuVSProvider.provider().gpuInfoProvider()
85+
new RealGPUMemoryService(CuVSProvider.provider().gpuInfoProvider())
8486
);
8587
}
8688

8789
private final ManagedCuVSResources[] pool;
8890
private final int capacity;
89-
private final GPUInfoProvider gpuInfoProvider;
91+
private final GPUMemoryService gpuMemoryService;
9092
private int createdCount;
9193

9294
ReentrantLock lock = new ReentrantLock();
9395
Condition enoughResourcesCondition = lock.newCondition();
9496

95-
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
97+
PoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) {
9698
if (capacity < 1 || capacity > MAX_RESOURCES) {
9799
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
98100
}
99101
this.capacity = capacity;
100-
this.gpuInfoProvider = gpuInfoProvider;
102+
this.gpuMemoryService = gpuMemoryService;
101103
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
102104
}
103105

104106
private ManagedCuVSResources getResourceFromPool() {
105107
for (int i = 0; i < createdCount; ++i) {
106108
var res = pool[i];
107-
if (res.locked == false) {
109+
if (res.isLocked() == false) {
108110
return res;
109111
}
110112
}
@@ -120,43 +122,49 @@ private int numLockedResources() {
120122
int lockedResources = 0;
121123
for (int i = 0; i < createdCount; ++i) {
122124
var res = pool[i];
123-
if (res.locked) {
125+
if (res.isLocked()) {
124126
lockedResources++;
125127
}
126128
}
127129
return lockedResources;
128130
}
129131

130132
@Override
131-
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
133+
public ManagedCuVSResources acquire(
134+
int numVectors,
135+
int dims,
136+
CuVSMatrix.DataType dataType,
137+
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
138+
) throws InterruptedException {
132139
try {
133140
var started = System.nanoTime();
134141
lock.lock();
135142

136143
boolean allConditionsMet = false;
137144
ManagedCuVSResources res = null;
145+
146+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, graphBuildAlgo);
147+
logger.debug(
148+
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
149+
numVectors,
150+
dims,
151+
dataType.name(),
152+
requiredMemoryInBytes
153+
);
154+
138155
while (allConditionsMet == false) {
139156
res = getResourceFromPool();
140157

141158
final boolean enoughMemory;
142159
if (res != null) {
143-
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
144-
logger.debug(
145-
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
146-
numVectors,
147-
dims,
148-
dataType.name(),
149-
requiredMemoryInBytes
150-
);
151-
152160
// Check immutable constraints
153-
long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
154-
if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
161+
long totalMemoryInBytes = gpuMemoryService.totalMemoryInBytes(res);
162+
if (requiredMemoryInBytes > totalMemoryInBytes) {
155163
String message = Strings.format(
156164
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
157165
numVectors,
158166
dims,
159-
totalDeviceMemoryInBytes
167+
totalMemoryInBytes
160168
);
161169
logger.error(message);
162170
throw new IllegalArgumentException(message);
@@ -169,9 +177,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
169177
}
170178

171179
// Check resources availability
172-
long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
173-
enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
174-
logger.debug("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory);
180+
long availableMemoryInBytes = gpuMemoryService.availableMemoryInBytes(res);
181+
enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes;
182+
logger.debug("Free device memory [{} B], enoughMemory[{}]", availableMemoryInBytes, enoughMemory);
175183
} else {
176184
logger.debug("No resources available in pool");
177185
enoughMemory = false;
@@ -184,14 +192,20 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
184192
}
185193
var elapsed = started - System.nanoTime();
186194
logger.debug("Resource acquired in [{}ms]", elapsed / 1_000_000.0);
187-
res.locked = true;
195+
gpuMemoryService.reserveMemory(requiredMemoryInBytes);
196+
res.lock(() -> gpuMemoryService.releaseMemory(requiredMemoryInBytes));
188197
return res;
189198
} finally {
190199
lock.unlock();
191200
}
192201
}
193202

194-
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
203+
private long estimateRequiredMemory(
204+
int numVectors,
205+
int dims,
206+
CuVSMatrix.DataType dataType,
207+
CagraIndexParams.CagraGraphBuildAlgo graphBuildAlgo
208+
) {
195209
int elementTypeBytes = switch (dataType) {
196210
case FLOAT -> Float.BYTES;
197211
case INT, UINT -> Integer.BYTES;
@@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) {
217231
logger.debug("Releasing resources to pool");
218232
try {
219233
lock.lock();
220-
assert resources.locked;
221-
resources.locked = false;
234+
assert resources.isLocked();
235+
resources.unlock();
222236
enoughResourcesCondition.signalAll();
223237
} finally {
224238
lock.unlock();
@@ -238,8 +252,9 @@ public void shutdown() {
238252
/** A managed resource. Cannot be closed. */
239253
final class ManagedCuVSResources implements CuVSResources {
240254

241-
final CuVSResources delegate;
242-
boolean locked = false;
255+
private final CuVSResources delegate;
256+
private static final Runnable NOT_LOCKED = () -> {};
257+
private Runnable unlockAction = NOT_LOCKED;
243258

244259
ManagedCuVSResources(CuVSResources resources) {
245260
this.delegate = resources;
@@ -269,5 +284,17 @@ public Path tempDirectory() {
269284
public String toString() {
270285
return "ManagedCuVSResources[delegate=" + delegate + "]";
271286
}
287+
288+
/**
 * Marks this resource as locked by storing the given action in {@code unlockAction}.
 * <p>
 * {@link #isLocked()} reports {@code true} for as long as the stored action is any
 * instance other than the {@code NOT_LOCKED} sentinel.
 * NOTE(review): the parameter name suggests the action is meant to be run when the
 * resource is unlocked -- verify against {@link #unlock()}.
 *
 * @param unlockAction the action to associate with this lock; passing the
 *                     {@code NOT_LOCKED} sentinel itself would leave the resource
 *                     reported as unlocked
 */
void lock(Runnable unlockAction) {
289+
this.unlockAction = unlockAction;
290+
}
291+
292+
void unlock() {
293+
unlockAction = NOT_LOCKED;
294+
}
295+
296+
/**
 * Reports whether this resource is currently locked.
 * <p>
 * The check is an identity comparison: the resource counts as locked whenever
 * {@code unlockAction} is any instance other than the shared {@code NOT_LOCKED}
 * sentinel.
 *
 * @return {@code true} if an unlock action other than {@code NOT_LOCKED} is stored
 */
boolean isLocked() {
297+
return unlockAction != NOT_LOCKED;
298+
}
272299
}
273300
}

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IO
193193
try (
194194
var resourcesHolder = new ResourcesHolder(
195195
cuVSResourceManager,
196-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), CuVSMatrix.DataType.FLOAT)
196+
cuVSResourceManager.acquire(
197+
numVectors,
198+
fieldInfo.getVectorDimension(),
199+
CuVSMatrix.DataType.FLOAT,
200+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
201+
)
197202
)
198203
) {
199204
var builder = CuVSMatrix.deviceBuilder(
@@ -533,7 +538,12 @@ private void mergeByteVectorField(
533538
var dataset = DatasetUtilsImpl.fromMemorySegment(packedSegment, numVectors, packedRowSize, dataType);
534539
var resourcesHolder = new ResourcesHolder(
535540
cuVSResourceManager,
536-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
541+
cuVSResourceManager.acquire(
542+
numVectors,
543+
fieldInfo.getVectorDimension(),
544+
dataType,
545+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
546+
)
537547
)
538548
) {
539549
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
@@ -557,7 +567,12 @@ private void mergeByteVectorField(
557567
var dataset = builder.build();
558568
var resourcesHolder = new ResourcesHolder(
559569
cuVSResourceManager,
560-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
570+
cuVSResourceManager.acquire(
571+
numVectors,
572+
fieldInfo.getVectorDimension(),
573+
dataType,
574+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
575+
)
561576
)
562577
) {
563578
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
@@ -578,7 +593,12 @@ private void mergeByteVectorField(
578593
var dataset = builder.build();
579594
var resourcesHolder = new ResourcesHolder(
580595
cuVSResourceManager,
581-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
596+
cuVSResourceManager.acquire(
597+
numVectors,
598+
fieldInfo.getVectorDimension(),
599+
dataType,
600+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
601+
)
582602
)
583603
) {
584604
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
@@ -605,7 +625,12 @@ private void mergeFloatVectorField(
605625
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
606626
var resourcesHolder = new ResourcesHolder(
607627
cuVSResourceManager,
608-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
628+
cuVSResourceManager.acquire(
629+
numVectors,
630+
fieldInfo.getVectorDimension(),
631+
dataType,
632+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
633+
)
609634
)
610635
) {
611636
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
@@ -628,7 +653,12 @@ private void mergeFloatVectorField(
628653
var dataset = builder.build();
629654
var resourcesHolder = new ResourcesHolder(
630655
cuVSResourceManager,
631-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
656+
cuVSResourceManager.acquire(
657+
numVectors,
658+
fieldInfo.getVectorDimension(),
659+
dataType,
660+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
661+
)
632662
)
633663
) {
634664
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
@@ -650,7 +680,12 @@ private void mergeFloatVectorField(
650680
var dataset = builder.build();
651681
var resourcesHolder = new ResourcesHolder(
652682
cuVSResourceManager,
653-
cuVSResourceManager.acquire(numVectors, fieldInfo.getVectorDimension(), dataType)
683+
cuVSResourceManager.acquire(
684+
numVectors,
685+
fieldInfo.getVectorDimension(),
686+
dataType,
687+
CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT
688+
)
654689
)
655690
) {
656691
generateGpuGraphAndWriteMeta(resourcesHolder, fieldInfo, dataset);
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import com.nvidia.cuvs.CuVSResources;
11+
12+
interface GPUMemoryService {
13+
14+
long totalMemoryInBytes(CuVSResources res);
15+
16+
long availableMemoryInBytes(CuVSResources res);
17+
18+
void reserveMemory(long memoryInBytes);
19+
20+
void releaseMemory(long memoryInBytes);
21+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import com.nvidia.cuvs.CuVSResources;
11+
import com.nvidia.cuvs.GPUInfoProvider;
12+
13+
class RealGPUMemoryService implements GPUMemoryService {
14+
private final GPUInfoProvider gpuInfoProvider;
15+
16+
RealGPUMemoryService(GPUInfoProvider gpuInfoProvider) {
17+
this.gpuInfoProvider = gpuInfoProvider;
18+
}
19+
20+
@Override
21+
public long totalMemoryInBytes(CuVSResources res) {
22+
return gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
23+
}
24+
25+
@Override
26+
public long availableMemoryInBytes(CuVSResources res) {
27+
return gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
28+
}
29+
30+
@Override
31+
public void reserveMemory(long memoryInBytes) {
32+
// No-op
33+
}
34+
35+
@Override
36+
public void releaseMemory(long memoryInBytes) {
37+
// No-op
38+
}
39+
}

0 commit comments

Comments
 (0)