Skip to content

Commit 2cf0388

Browse files
committed
Fix: re-acquire the pooled resource (res) before re-evaluating the wait condition(s)
1 parent 204405b commit 2cf0388

File tree

1 file changed

+30
-23
lines changed

1 file changed

+30
-23
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8484
private int createdCount;
8585

8686
ReentrantLock lock = new ReentrantLock();
87-
Condition poolAvailableCondition = lock.newCondition();
8887
Condition enoughMemoryCondition = lock.newCondition();
8988

9089
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
@@ -115,28 +114,37 @@ private ManagedCuVSResources getResourceFromPool() {
115114
public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException {
116115
try {
117116
lock.lock();
118-
ManagedCuVSResources res;
119-
while ((res = getResourceFromPool()) == null) {
120-
poolAvailableCondition.await();
121-
}
122117

123-
// Check resources availability
124-
// Memory
125-
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
126-
if (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes()) {
127-
throw new IllegalArgumentException(
128-
Strings.format(
129-
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
130-
numVectors,
131-
dims,
132-
gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes() / (1024L * 1024L)
133-
)
134-
);
135-
}
136-
while (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes()) {
137-
enoughMemoryCondition.await();
118+
boolean allConditionsMet = false;
119+
ManagedCuVSResources res = null;
120+
while (allConditionsMet == false) {
121+
res = getResourceFromPool();
122+
final boolean enoughMemory;
123+
if (res != null) {
124+
// Check resources availability
125+
// Memory
126+
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
127+
if (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes()) {
128+
throw new IllegalArgumentException(
129+
Strings.format(
130+
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
131+
numVectors,
132+
dims,
133+
gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes() / (1024L * 1024L)
134+
)
135+
);
136+
}
137+
enoughMemory = requiredMemoryInBytes <= gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
138+
} else {
139+
enoughMemory = false;
140+
}
141+
if (enoughMemory == false) {
142+
enoughMemoryCondition.await();
143+
}
144+
145+
// TODO: add enoughComputation / enoughComputationCondition here
146+
allConditionsMet = enoughMemory; // && enoughComputation
138147
}
139-
140148
res.locked = true;
141149
return res;
142150
} finally {
@@ -165,8 +173,7 @@ public void release(ManagedCuVSResources resources) {
165173
lock.lock();
166174
assert resources.locked;
167175
resources.locked = false;
168-
poolAvailableCondition.signalAll();
169-
enoughMemoryCondition.signalAll();
176+
enoughMemoryCondition.signal();
170177
} finally {
171178
lock.unlock();
172179
}

0 commit comments

Comments
 (0)