Skip to content

Commit 2b0cfa5

Browse files
committed
PoolingCuVSResourceManager with memory availability
1 parent f452dfc commit 2b0cfa5

File tree

2 files changed

+107
-21
lines changed

2 files changed

+107
-21
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,17 @@
99

1010
import com.nvidia.cuvs.CuVSResources;
1111

12+
import com.nvidia.cuvs.GPUInfoProvider;
13+
14+
import com.nvidia.cuvs.spi.CuVSProvider;
15+
16+
import org.elasticsearch.core.Strings;
1217
import org.elasticsearch.xpack.gpu.GPUSupport;
1318

1419
import java.nio.file.Path;
1520
import java.util.Objects;
16-
import java.util.concurrent.ArrayBlockingQueue;
17-
import java.util.concurrent.BlockingQueue;
21+
import java.util.concurrent.locks.Condition;
22+
import java.util.concurrent.locks.ReentrantLock;
1823

1924
/**
2025
* A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
@@ -65,35 +70,84 @@ static CuVSResourceManager pooling() {
6570
*/
6671
class PoolingCuVSResourceManager implements CuVSResourceManager {
6772

73+
/** A multiplier on input data to account for intermediate and output data size required while processing it */
74+
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
6875
static final int MAX_RESOURCES = 2;
69-
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES);
76+
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
77+
MAX_RESOURCES,
78+
CuVSProvider.provider().gpuInfoProvider()
79+
);
7080

71-
final BlockingQueue<ManagedCuVSResources> pool;
81+
final ManagedCuVSResources[] pool;
7282
final int capacity;
83+
final GPUInfoProvider gpuInfoProvider;
7384
int createdCount;
7485

75-
public PoolingCuVSResourceManager(int capacity) {
86+
ReentrantLock lock = new ReentrantLock();
87+
Condition poolAvailableCondition = lock.newCondition();
88+
Condition enoughMemoryCondition = lock.newCondition();
89+
90+
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
7691
if (capacity < 1 || capacity > MAX_RESOURCES) {
7792
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
7893
}
7994
this.capacity = capacity;
80-
this.pool = new ArrayBlockingQueue<>(capacity);
95+
this.gpuInfoProvider = gpuInfoProvider;
96+
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
8197
}
8298

83-
@Override
84-
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
85-
ManagedCuVSResources res = pool.poll();
86-
if (res != null) {
99+
private ManagedCuVSResources getResourceFromPool() {
100+
for (int i = 0; i < createdCount; ++i) {
101+
var res = pool[i];
102+
if (res.locked == false) {
103+
return res;
104+
}
105+
}
106+
if (createdCount < capacity) {
107+
var res = new ManagedCuVSResources(Objects.requireNonNull(createNew()));
108+
pool[createdCount++] = res;
87109
return res;
88110
}
89-
synchronized (this) {
90-
if (createdCount < capacity) {
91-
createdCount++;
92-
return new ManagedCuVSResources(Objects.requireNonNull(createNew()));
111+
return null;
112+
}
113+
114+
@Override
115+
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
116+
try {
117+
lock.lock();
118+
ManagedCuVSResources res;
119+
while ((res = getResourceFromPool()) == null) {
120+
poolAvailableCondition.await();
121+
}
122+
123+
// Check resources availability
124+
var resourcesInfo = gpuInfoProvider.getCurrentInfo(res);
125+
126+
// Memory
127+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
128+
if (requiredMemoryInBytes > resourcesInfo.totalDeviceMemoryInBytes()) {
129+
throw new IllegalArgumentException(
130+
Strings.format(
131+
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
132+
numVectors,
133+
dims,
134+
resourcesInfo.totalDeviceMemoryInBytes() / 1048576.0f
135+
)
136+
);
137+
}
138+
while (requiredMemoryInBytes > resourcesInfo.freeDeviceMemoryInBytes()) {
139+
enoughMemoryCondition.await();
93140
}
141+
142+
res.locked = true;
143+
return res;
144+
} finally {
145+
lock.unlock();
94146
}
95-
// Otherwise, wait for one to be released
96-
return pool.take();
147+
}
148+
149+
private long estimateRequiredMemory(int numVectors, int dims) {
150+
return (long)(GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * Float.BYTES);
97151
}
98152

99153
// visible for testing
@@ -104,27 +158,37 @@ protected CuVSResources createNew() {
104158
@Override
105159
public void finishedComputation(ManagedCuVSResources resources) {
106160
// currently does nothing, but could allow acquire to return possibly blocked resources
161+
// something like enoughComputationCondition.signal()?
107162
}
108163

109164
@Override
110165
public void release(ManagedCuVSResources resources) {
111-
var added = pool.offer(Objects.requireNonNull(resources));
112-
assert added : "Failed to release resource back to pool";
166+
try {
167+
lock.lock();
168+
assert resources.locked;
169+
resources.locked = false;
170+
poolAvailableCondition.signalAll();
171+
enoughMemoryCondition.signalAll();
172+
} finally {
173+
lock.unlock();
174+
}
113175
}
114176

115177
@Override
116178
public void shutdown() {
117-
for (ManagedCuVSResources res : pool) {
179+
for (int i = 0; i < createdCount; ++i) {
180+
var res = pool[i];
181+
assert res != null;
118182
res.delegate.close();
119183
}
120-
pool.clear();
121184
}
122185
}
123186

124187
/** A managed resource. Cannot be closed. */
125188
final class ManagedCuVSResources implements CuVSResources {
126189

127190
final CuVSResources delegate;
191+
boolean locked = false;
128192

129193
ManagedCuVSResources(CuVSResources resources) {
130194
this.delegate = resources;

x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99

1010
import com.nvidia.cuvs.CuVSResources;
1111

12+
import com.nvidia.cuvs.CuVSResourcesInfo;
13+
import com.nvidia.cuvs.GPUInfo;
14+
import com.nvidia.cuvs.GPUInfoProvider;
15+
1216
import org.elasticsearch.test.ESTestCase;
1317

1418
import java.nio.file.Path;
19+
import java.util.List;
1520
import java.util.concurrent.atomic.AtomicInteger;
1621
import java.util.concurrent.atomic.AtomicReference;
1722

@@ -83,7 +88,7 @@ static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingC
8388
final AtomicInteger idGenerator = new AtomicInteger();
8489

8590
MockPoolingCuVSResourceManager(int capacity) {
86-
super(capacity);
91+
super(capacity, new MockGPUInfoProvider());
8792
}
8893

8994
@Override
@@ -118,4 +123,21 @@ public String toString() {
118123
return "MockCuVSResources[id=" + id + "]";
119124
}
120125
}
126+
127+
private static class MockGPUInfoProvider implements GPUInfoProvider {
128+
@Override
129+
public List<GPUInfo> availableGPUs() throws Throwable {
130+
throw new UnsupportedOperationException();
131+
}
132+
133+
@Override
134+
public List<GPUInfo> compatibleGPUs() throws Throwable {
135+
throw new UnsupportedOperationException();
136+
}
137+
138+
@Override
139+
public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) {
140+
return new CuVSResourcesInfo(256L * 1024 * 1024, 2048L * 1024 * 1024);
141+
}
142+
}
121143
}

0 commit comments

Comments
 (0)