77
88package org .elasticsearch .xpack .gpu .codec ;
99
10+ import com .nvidia .cuvs .CagraIndexParams ;
1011import com .nvidia .cuvs .CuVSMatrix ;
1112import com .nvidia .cuvs .CuVSResources ;
12- import com .nvidia .cuvs .GPUInfoProvider ;
1313import com .nvidia .cuvs .spi .CuVSProvider ;
1414
1515import org .elasticsearch .core .Strings ;
@@ -47,10 +47,12 @@ public interface CuVSResourceManager {
4747 * effect on GPU memory and compute usage to determine whether to give out
4848 * another resource or wait for a resources to be returned before giving out another.
4949 */
50- // numVectors and dims are currently unused, but could be used along with GPU metadata,
51- // memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
52- // to give out a resources or not.
53- ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType ) throws InterruptedException ;
50+ ManagedCuVSResources acquire (
51+ int numVectors ,
52+ int dims ,
53+ CuVSMatrix .DataType dataType ,
54+ CagraIndexParams .CagraGraphBuildAlgo graphBuildAlgo
55+ ) throws InterruptedException ;
5456
5557 /** Marks the resources as finished with regard to compute. */
5658 void finishedComputation (ManagedCuVSResources resources );
@@ -80,31 +82,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8082 static class Holder {
8183 static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (
8284 MAX_RESOURCES ,
83- CuVSProvider .provider ().gpuInfoProvider ()
85+ new RealGPUMemoryService ( CuVSProvider .provider ().gpuInfoProvider () )
8486 );
8587 }
8688
8789 private final ManagedCuVSResources [] pool ;
8890 private final int capacity ;
89- private final GPUInfoProvider gpuInfoProvider ;
91+ private final GPUMemoryService gpuMemoryService ;
9092 private int createdCount ;
9193
9294 ReentrantLock lock = new ReentrantLock ();
9395 Condition enoughResourcesCondition = lock .newCondition ();
9496
95- public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
97+ PoolingCuVSResourceManager (int capacity , GPUMemoryService gpuMemoryService ) {
9698 if (capacity < 1 || capacity > MAX_RESOURCES ) {
9799 throw new IllegalArgumentException ("Resource count must be between 1 and " + MAX_RESOURCES );
98100 }
99101 this .capacity = capacity ;
100- this .gpuInfoProvider = gpuInfoProvider ;
102+ this .gpuMemoryService = gpuMemoryService ;
101103 this .pool = new ManagedCuVSResources [MAX_RESOURCES ];
102104 }
103105
104106 private ManagedCuVSResources getResourceFromPool () {
105107 for (int i = 0 ; i < createdCount ; ++i ) {
106108 var res = pool [i ];
107- if (res .locked == false ) {
109+ if (res .isLocked () == false ) {
108110 return res ;
109111 }
110112 }
@@ -120,43 +122,49 @@ private int numLockedResources() {
120122 int lockedResources = 0 ;
121123 for (int i = 0 ; i < createdCount ; ++i ) {
122124 var res = pool [i ];
123- if (res .locked ) {
125+ if (res .isLocked () ) {
124126 lockedResources ++;
125127 }
126128 }
127129 return lockedResources ;
128130 }
129131
130132 @ Override
131- public ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType ) throws InterruptedException {
133+ public ManagedCuVSResources acquire (
134+ int numVectors ,
135+ int dims ,
136+ CuVSMatrix .DataType dataType ,
137+ CagraIndexParams .CagraGraphBuildAlgo graphBuildAlgo
138+ ) throws InterruptedException {
132139 try {
133140 var started = System .nanoTime ();
134141 lock .lock ();
135142
136143 boolean allConditionsMet = false ;
137144 ManagedCuVSResources res = null ;
145+
146+ long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims , dataType , graphBuildAlgo );
147+ logger .debug (
148+ "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]" ,
149+ numVectors ,
150+ dims ,
151+ dataType .name (),
152+ requiredMemoryInBytes
153+ );
154+
138155 while (allConditionsMet == false ) {
139156 res = getResourceFromPool ();
140157
141158 final boolean enoughMemory ;
142159 if (res != null ) {
143- long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims , dataType );
144- logger .debug (
145- "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]" ,
146- numVectors ,
147- dims ,
148- dataType .name (),
149- requiredMemoryInBytes
150- );
151-
152160 // Check immutable constraints
153- long totalDeviceMemoryInBytes = gpuInfoProvider . getCurrentInfo (res ). totalDeviceMemoryInBytes ( );
154- if (requiredMemoryInBytes > totalDeviceMemoryInBytes ) {
161+ long totalMemoryInBytes = gpuMemoryService . totalMemoryInBytes (res );
162+ if (requiredMemoryInBytes > totalMemoryInBytes ) {
155163 String message = Strings .format (
156164 "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]" ,
157165 numVectors ,
158166 dims ,
159- totalDeviceMemoryInBytes
167+ totalMemoryInBytes
160168 );
161169 logger .error (message );
162170 throw new IllegalArgumentException (message );
@@ -169,9 +177,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
169177 }
170178
171179 // Check resources availability
172- long freeDeviceMemoryInBytes = gpuInfoProvider . getCurrentInfo (res ). freeDeviceMemoryInBytes ( );
173- enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes ;
174- logger .debug ("Free device memory [{} B], enoughMemory[{}]" , freeDeviceMemoryInBytes , enoughMemory );
180+ long availableMemoryInBytes = gpuMemoryService . availableMemoryInBytes (res );
181+ enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes ;
182+ logger .debug ("Free device memory [{} B], enoughMemory[{}]" , availableMemoryInBytes , enoughMemory );
175183 } else {
176184 logger .debug ("No resources available in pool" );
177185 enoughMemory = false ;
@@ -184,14 +192,20 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
184192 }
185193 var elapsed = started - System .nanoTime ();
186194 logger .debug ("Resource acquired in [{}ms]" , elapsed / 1_000_000.0 );
187- res .locked = true ;
195+ gpuMemoryService .reserveMemory (requiredMemoryInBytes );
196+ res .lock (() -> gpuMemoryService .releaseMemory (requiredMemoryInBytes ));
188197 return res ;
189198 } finally {
190199 lock .unlock ();
191200 }
192201 }
193202
194- private long estimateRequiredMemory (int numVectors , int dims , CuVSMatrix .DataType dataType ) {
203+ private long estimateRequiredMemory (
204+ int numVectors ,
205+ int dims ,
206+ CuVSMatrix .DataType dataType ,
207+ CagraIndexParams .CagraGraphBuildAlgo graphBuildAlgo
208+ ) {
195209 int elementTypeBytes = switch (dataType ) {
196210 case FLOAT -> Float .BYTES ;
197211 case INT , UINT -> Integer .BYTES ;
@@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) {
217231 logger .debug ("Releasing resources to pool" );
218232 try {
219233 lock .lock ();
220- assert resources .locked ;
221- resources .locked = false ;
234+ assert resources .isLocked () ;
235+ resources .unlock () ;
222236 enoughResourcesCondition .signalAll ();
223237 } finally {
224238 lock .unlock ();
@@ -238,8 +252,9 @@ public void shutdown() {
238252 /** A managed resource. Cannot be closed. */
239253 final class ManagedCuVSResources implements CuVSResources {
240254
241- final CuVSResources delegate ;
242- boolean locked = false ;
255+ private final CuVSResources delegate ;
256+ private static final Runnable NOT_LOCKED = () -> {};
257+ private Runnable unlockAction = NOT_LOCKED ;
243258
244259 ManagedCuVSResources (CuVSResources resources ) {
245260 this .delegate = resources ;
@@ -269,5 +284,17 @@ public Path tempDirectory() {
269284 public String toString () {
270285 return "ManagedCuVSResources[delegate=" + delegate + "]" ;
271286 }
287+
288+ void lock (Runnable unlockAction ) {
289+ this .unlockAction = unlockAction ;
290+ }
291+
292+ void unlock () {
293+ unlockAction = NOT_LOCKED ;
294+ }
295+
296+ boolean isLocked () {
297+ return unlockAction != NOT_LOCKED ;
298+ }
272299 }
273300}
0 commit comments