@@ -84,7 +84,6 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8484 private int createdCount ;
8585
8686 ReentrantLock lock = new ReentrantLock ();
87- Condition poolAvailableCondition = lock .newCondition ();
8887 Condition enoughMemoryCondition = lock .newCondition ();
8988
9089 public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
@@ -115,28 +114,37 @@ private ManagedCuVSResources getResourceFromPool() {
115114 public ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException {
116115 try {
117116 lock .lock ();
118- ManagedCuVSResources res ;
119- while ((res = getResourceFromPool ()) == null ) {
120- poolAvailableCondition .await ();
121- }
122117
123- // Check resources availability
124- // Memory
125- long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims );
126- if (requiredMemoryInBytes > gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes ()) {
127- throw new IllegalArgumentException (
128- Strings .format (
129- "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]" ,
130- numVectors ,
131- dims ,
132- gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes () / (1024L * 1024L )
133- )
134- );
135- }
136- while (requiredMemoryInBytes > gpuInfoProvider .getCurrentInfo (res ).freeDeviceMemoryInBytes ()) {
137- enoughMemoryCondition .await ();
118+ boolean allConditionsMet = false ;
119+ ManagedCuVSResources res = null ;
120+ while (allConditionsMet == false ) {
121+ res = getResourceFromPool ();
122+ final boolean enoughMemory ;
123+ if (res != null ) {
124+ // Check resources availability
125+ // Memory
126+ long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims );
127+ if (requiredMemoryInBytes > gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes ()) {
128+ throw new IllegalArgumentException (
129+ Strings .format (
130+ "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]" ,
131+ numVectors ,
132+ dims ,
133+ gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes () / (1024L * 1024L )
134+ )
135+ );
136+ }
137+ enoughMemory = requiredMemoryInBytes <= gpuInfoProvider .getCurrentInfo (res ).freeDeviceMemoryInBytes ();
138+ } else {
139+ enoughMemory = false ;
140+ }
141+ if (enoughMemory == false ) {
142+ enoughMemoryCondition .await ();
143+ }
144+
145+ // TODO: add enoughComputation / enoughComputationCondition here
146+ allConditionsMet = enoughMemory ; // && enoughComputation
138147 }
139-
140148 res .locked = true ;
141149 return res ;
142150 } finally {
@@ -165,8 +173,7 @@ public void release(ManagedCuVSResources resources) {
165173 lock .lock ();
166174 assert resources .locked ;
167175 resources .locked = false ;
168- poolAvailableCondition .signalAll ();
169- enoughMemoryCondition .signalAll ();
176+ enoughMemoryCondition .signal ();
170177 } finally {
171178 lock .unlock ();
172179 }
0 commit comments