|
24 | 24 | import static org.hamcrest.Matchers.anyOf; |
25 | 25 | import static org.hamcrest.Matchers.containsString; |
26 | 26 | import static org.hamcrest.Matchers.equalTo; |
| 27 | +import static org.hamcrest.Matchers.lessThan; |
27 | 28 | import static org.hamcrest.Matchers.not; |
28 | 29 |
|
29 | 30 | public class CuVSResourceManagerTests extends ESTestCase { |
@@ -57,6 +58,27 @@ public void testBasicWithIvfPq() throws InterruptedException { |
57 | 58 | testBasic(createIvfPqParams()); |
58 | 59 | } |
59 | 60 |
|
| 61 | + public void testMultipleAcquireRelease() throws InterruptedException { |
| 62 | + var mgr = new MockPoolingCuVSResourceManager(2); |
| 63 | + var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); |
| 64 | + var res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams()); |
| 65 | + assertThat(res1.toString(), containsString("id=0")); |
| 66 | + assertThat(res2.toString(), containsString("id=1")); |
| 67 | + assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2)); |
| 68 | + mgr.release(res1); |
| 69 | + mgr.release(res2); |
| 70 | + assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES)); |
| 71 | + res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createNnDescentParams()); |
| 72 | + res2 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT, createIvfPqParams()); |
| 73 | + assertThat(res1.toString(), containsString("id=0")); |
| 74 | + assertThat(res2.toString(), containsString("id=1")); |
| 75 | + assertThat(mgr.availableMemory(), lessThan(TOTAL_DEVICE_MEMORY_IN_BYTES / 2)); |
| 76 | + mgr.release(res1); |
| 77 | + mgr.release(res2); |
| 78 | + assertThat(mgr.availableMemory(), equalTo(TOTAL_DEVICE_MEMORY_IN_BYTES)); |
| 79 | + mgr.shutdown(); |
| 80 | + } |
| 81 | + |
60 | 82 | private static void testBlocking(CagraIndexParams params) throws Exception { |
61 | 83 | var mgr = new MockPoolingCuVSResourceManager(2); |
62 | 84 | var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT, params); |
@@ -185,13 +207,23 @@ private static CagraIndexParams createIvfPqParams() { |
185 | 207 | static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager { |
186 | 208 |
|
187 | 209 | private final AtomicInteger idGenerator = new AtomicInteger(); |
| 210 | + private final GPUMemoryService gpuMemoryService; |
188 | 211 |
|
189 | 212 | MockPoolingCuVSResourceManager(int capacity) { |
190 | 213 | this(capacity, TOTAL_DEVICE_MEMORY_IN_BYTES); |
191 | 214 | } |
192 | 215 |
|
193 | 216 | MockPoolingCuVSResourceManager(int capacity, long totalMemoryInBytes) { |
194 | | - super(capacity, new TrackingGPUMemoryService(totalMemoryInBytes)); |
| 217 | + this(capacity, new TrackingGPUMemoryService(totalMemoryInBytes)); |
| 218 | + } |
| 219 | + |
| 220 | + private MockPoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) { |
| 221 | + super(capacity, gpuMemoryService); |
| 222 | + this.gpuMemoryService = gpuMemoryService; |
| 223 | + } |
| 224 | + |
| 225 | + long availableMemory() { |
| 226 | + return gpuMemoryService.availableMemoryInBytes(null); |
195 | 227 | } |
196 | 228 |
|
197 | 229 | @Override |
|
0 commit comments