Skip to content

Commit 97d465e

Browse files
committed
PR comments: added javadoc, fixed unlock issue (+ test for it)
1 parent 8416562 commit 97d465e

File tree

5 files changed

+47
-1
lines changed

5 files changed

+47
-1
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ void lock(Runnable unlockAction) {
290290
}
291291

292292
void unlock() {
293+
unlockAction.run();
293294
unlockAction = NOT_LOCKED;
294295
}
295296

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUMemoryService.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
import com.nvidia.cuvs.CuVSResources;
1111

12+
/**
13+
* Abstracts GPU memory tracking (total vs. available memory).
14+
*/
1215
interface GPUMemoryService {
1316

1417
long totalMemoryInBytes(CuVSResources res);

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/RealGPUMemoryService.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import com.nvidia.cuvs.CuVSResources;
1111
import com.nvidia.cuvs.GPUInfoProvider;
1212

13+
/**
14+
* A {@link GPUMemoryService} that tracks how much memory is currently used/available on a GPU by using the GPU free/total memory APIs
15+
* (via a {@link GPUInfoProvider}).
16+
*/
1317
class RealGPUMemoryService implements GPUMemoryService {
1418
private final GPUInfoProvider gpuInfoProvider;
1519

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/TrackingGPUMemoryService.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
import com.nvidia.cuvs.CuVSResources;
1111

12+
/**
13+
* A {@link GPUMemoryService} that manually tracks how much memory is currently estimated to be used/available on a GPU.
14+
* This implementation is useful when we are not able to use a "Real memory" measurement; for example, if we are using pooled RMM memory,
15+
* the pool will permanently occupy most of the GPU RAM, allocations will happen inside the pool, and the "Real memory" measurement API
16+
* will always report a (tiny) fixed amount of free memory (whatever is not in the pool).
17+
*/
1218
class TrackingGPUMemoryService implements GPUMemoryService {
1319

1420
private final long totalMemoryInBytes;

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.hamcrest.Matchers.anyOf;
2525
import static org.hamcrest.Matchers.containsString;
2626
import static org.hamcrest.Matchers.equalTo;
27+
import static org.hamcrest.Matchers.lessThan;
2728
import static org.hamcrest.Matchers.not;
2829

2930
public class CuVSResourceManagerTests extends ESTestCase {
@@ -57,6 +58,27 @@ public void testBasicWithIvfPq() throws InterruptedException {
5758
testBasic(createIvfPqParams());
5859
}
5960

61+
public void testMultipleAcquireRelease() throws InterruptedException {
62+
var mgr = new MockPoolingCuVSResourceManager(2);
63+
var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams());
64+
var res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams());
65+
assertThat(res1.toString(), containsString("id=0"));
66+
assertThat(res2.toString(), containsString("id=1"));
67+
assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2));
68+
mgr.release(res1);
69+
mgr.release(res2);
70+
assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES));
71+
res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams());
72+
res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams());
73+
assertThat(res1.toString(), containsString("id=0"));
74+
assertThat(res2.toString(), containsString("id=1"));
75+
assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2));
76+
mgr.release(res1);
77+
mgr.release(res2);
78+
assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES));
79+
mgr.shutdown();
80+
}
81+
6082
private static void testBlocking(CagraIndexParams params) throws Exception {
6183
var mgr = new MockPoolingCuVSResourceManager(2);
6284
var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params);
@@ -185,13 +207,23 @@ private static CagraIndexParams createIvfPqParams() {
185207
static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager {
186208

187209
private final AtomicInteger idGenerator = new AtomicInteger();
210+
private final GPUMemoryService gpuMemoryService;
188211

189212
MockPoolingCuVSResourceManager(int capacity) {
190213
this(capacity, TOTAL_DEVICE_MEMORY_IN_BYTES);
191214
}
192215

193216
MockPoolingCuVSResourceManager(int capacity, long totalMemoryInBytes) {
194-
super(capacity, new TrackingGPUMemoryService(totalMemoryInBytes));
217+
this(capacity, new TrackingGPUMemoryService(totalMemoryInBytes));
218+
}
219+
220+
private MockPoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) {
221+
super(capacity, gpuMemoryService);
222+
this.gpuMemoryService = gpuMemoryService;
223+
}
224+
225+
long availableMemory() {
226+
return gpuMemoryService.availableMemoryInBytes(null);
195227
}
196228

197229
@Override

0 commit comments

Comments
 (0)