Skip to content

Commit 204405b

Browse files
committed
Add tests + fixes
1 parent 2b0cfa5 commit 204405b

File tree

2 files changed

+108
-13
lines changed

2 files changed

+108
-13
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -78,10 +78,10 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
7878
CuVSProvider.provider().gpuInfoProvider()
7979
);
8080

81-
final ManagedCuVSResources[] pool;
82-
final int capacity;
83-
final GPUInfoProvider gpuInfoProvider;
84-
int createdCount;
81+
private final ManagedCuVSResources[] pool;
82+
private final int capacity;
83+
private final GPUInfoProvider gpuInfoProvider;
84+
private int createdCount;
8585

8686
ReentrantLock lock = new ReentrantLock();
8787
Condition poolAvailableCondition = lock.newCondition();
@@ -121,21 +121,19 @@ public ManagedCuVSResources acquire(int numVectors, int dims) throws Interrupted
121121
}
122122

123123
// Check resources availability
124-
var resourcesInfo = gpuInfoProvider.getCurrentInfo(res);
125-
126124
// Memory
127125
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
128-
if (requiredMemoryInBytes > resourcesInfo.totalDeviceMemoryInBytes()) {
126+
if (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes()) {
129127
throw new IllegalArgumentException(
130128
Strings.format(
131129
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
132130
numVectors,
133131
dims,
134-
resourcesInfo.totalDeviceMemoryInBytes() / 1048576.0f
132+
gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes() / (1024L * 1024L)
135133
)
136134
);
137135
}
138-
while (requiredMemoryInBytes > resourcesInfo.freeDeviceMemoryInBytes()) {
136+
while (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes()) {
139137
enoughMemoryCondition.await();
140138
}
141139

@@ -199,6 +197,11 @@ public ScopedAccess access() {
199197
return delegate.access();
200198
}
201199

200+
@Override
201+
public int deviceId() {
202+
return delegate.deviceId();
203+
}
204+
202205
@Override
203206
public void close() {
204207
throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java

Lines changed: 96 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,18 +13,28 @@
1313
import com.nvidia.cuvs.GPUInfo;
1414
import com.nvidia.cuvs.GPUInfoProvider;
1515

16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.logging.Logger;
1618
import org.elasticsearch.test.ESTestCase;
1719

1820
import java.nio.file.Path;
21+
import java.util.ArrayList;
1922
import java.util.List;
2023
import java.util.concurrent.atomic.AtomicInteger;
2124
import java.util.concurrent.atomic.AtomicReference;
25+
import java.util.function.LongSupplier;
2226

2327
import static org.hamcrest.Matchers.anyOf;
2428
import static org.hamcrest.Matchers.containsString;
29+
import static org.hamcrest.Matchers.equalTo;
30+
import static org.hamcrest.Matchers.not;
2531

2632
public class CuVSResourceManagerTests extends ESTestCase {
2733

34+
private static final Logger log = LogManager.getLogger(CuVSResourceManagerTests.class);
35+
36+
public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024;
37+
2838
public void testBasic() throws InterruptedException {
2939
var mgr = new MockPoolingCuVSResourceManager(2);
3040
var res1 = mgr.acquire(0, 0);
@@ -65,10 +75,52 @@ public void testBlocking() throws Exception {
6575
mgr.shutdown();
6676
}
6777

78+
public void testBlockingOnInsufficientMemory() throws Exception {
79+
var mgr = new MockPoolingCuVSResourceManager(2);
80+
var res1 = mgr.acquire(16 * 1024, 1024);
81+
82+
AtomicReference<CuVSResources> holder = new AtomicReference<>();
83+
Thread t = new Thread(() -> {
84+
try {
85+
var res2 = mgr.acquire((16 * 1024) + 1, 1024);
86+
holder.set(res2);
87+
} catch (InterruptedException e) {
88+
throw new AssertionError(e);
89+
}
90+
});
91+
t.start();
92+
Thread.sleep(1_000);
93+
assertNull(holder.get());
94+
mgr.release(res1);
95+
t.join();
96+
assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1")));
97+
mgr.shutdown();
98+
}
99+
100+
public void testNotBlockingOnSufficientMemory() throws Exception {
101+
var mgr = new MockPoolingCuVSResourceManager(2);
102+
var res1 = mgr.acquire(16 * 1024, 1024);
103+
104+
AtomicReference<CuVSResources> holder = new AtomicReference<>();
105+
Thread t = new Thread(() -> {
106+
try {
107+
var res2 = mgr.acquire((16 * 1024) - 1, 1024);
108+
holder.set(res2);
109+
} catch (InterruptedException e) {
110+
throw new AssertionError(e);
111+
}
112+
});
113+
t.start();
114+
t.join(5_000);
115+
assertNotNull(holder.get());
116+
assertThat(holder.get().toString(), not(equalTo(res1.toString())));
117+
mgr.shutdown();
118+
}
119+
68120
public void testManagedResIsNotClosable() throws Exception {
69121
var mgr = new MockPoolingCuVSResourceManager(1);
70122
var res = mgr.acquire(0, 0);
71-
assertThrows(UnsupportedOperationException.class, () -> res.close());
123+
assertThrows(UnsupportedOperationException.class, res::close);
72124
mgr.release(res);
73125
mgr.shutdown();
74126
}
@@ -85,16 +137,45 @@ public void testDoubleRelease() throws InterruptedException {
85137

86138
static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager {
87139

88-
final AtomicInteger idGenerator = new AtomicInteger();
140+
private final AtomicInteger idGenerator = new AtomicInteger();
141+
private final List<Long> allocations;
89142

90143
MockPoolingCuVSResourceManager(int capacity) {
91-
super(capacity, new MockGPUInfoProvider());
144+
this(capacity, new ArrayList<>());
145+
}
146+
147+
private MockPoolingCuVSResourceManager(int capacity, List<Long> allocationList) {
148+
super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList)));
149+
this.allocations = allocationList;
150+
}
151+
152+
private static long freeMemoryFunction(List<Long> allocations) {
153+
return TOTAL_DEVICE_MEMORY_IN_BYTES - allocations.stream().mapToLong(x -> x).sum();
92154
}
93155

94156
@Override
95157
protected CuVSResources createNew() {
96158
return new MockCuVSResources(idGenerator.getAndIncrement());
97159
}
160+
161+
@Override
162+
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
163+
var res = super.acquire(numVectors, dims);
164+
long memory = (long)(numVectors * dims * Float.BYTES *
165+
CuVSResourceManager.PoolingCuVSResourceManager.GPU_COMPUTATION_MEMORY_FACTOR);
166+
allocations.add(memory);
167+
log.info("Added [{}]", memory);
168+
return res;
169+
}
170+
171+
@Override
172+
public void release(ManagedCuVSResources resources) {
173+
if (allocations.isEmpty() == false) {
174+
var x = allocations.removeLast();
175+
log.info("Removed [{}]", x);
176+
}
177+
super.release(resources);
178+
}
98179
}
99180

100181
static class MockCuVSResources implements CuVSResources {
@@ -110,6 +191,11 @@ public ScopedAccess access() {
110191
throw new UnsupportedOperationException();
111192
}
112193

194+
@Override
195+
public int deviceId() {
196+
return 0;
197+
}
198+
113199
@Override
114200
public void close() {}
115201

@@ -125,6 +211,12 @@ public String toString() {
125211
}
126212

127213
private static class MockGPUInfoProvider implements GPUInfoProvider {
214+
private final LongSupplier freeMemorySupplier;
215+
216+
MockGPUInfoProvider(LongSupplier freeMemorySupplier) {
217+
this.freeMemorySupplier = freeMemorySupplier;
218+
}
219+
128220
@Override
129221
public List<GPUInfo> availableGPUs() throws Throwable {
130222
throw new UnsupportedOperationException();
@@ -137,7 +229,7 @@ public List<GPUInfo> compatibleGPUs() throws Throwable {
137229

138230
@Override
139231
public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) {
140-
return new CuVSResourcesInfo(256L * 1024 * 1024, 2048L * 1024 * 1024);
232+
return new CuVSResourcesInfo(freeMemorySupplier.getAsLong(), TOTAL_DEVICE_MEMORY_IN_BYTES);
141233
}
142234
}
143235
}

0 commit comments

Comments
 (0)