77
88package org .elasticsearch .xpack .gpu .codec ;
99
10+ import com .nvidia .cuvs .CagraIndexParams ;
1011import com .nvidia .cuvs .CuVSMatrix ;
1112import com .nvidia .cuvs .CuVSResources ;
12- import com .nvidia .cuvs .GPUInfoProvider ;
1313import com .nvidia .cuvs .spi .CuVSProvider ;
1414
1515import org .elasticsearch .core .Strings ;
@@ -47,10 +47,8 @@ public interface CuVSResourceManager {
4747 * effect on GPU memory and compute usage to determine whether to give out
4848 * another resource or wait for a resource to be returned before giving out another.
4949 */
50- // numVectors and dims are currently unused, but could be used along with GPU metadata,
51- // memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
52- // to give out a resources or not.
53- ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType ) throws InterruptedException ;
50+ ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType , CagraIndexParams cagraIndexParams )
51+ throws InterruptedException ;
5452
5553 /** Marks the resources as finished with regard to compute. */
5654 void finishedComputation (ManagedCuVSResources resources );
@@ -80,31 +78,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
8078 static class Holder {
8179 static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (
8280 MAX_RESOURCES ,
83- CuVSProvider .provider ().gpuInfoProvider ()
81+ new RealGPUMemoryService ( CuVSProvider .provider ().gpuInfoProvider () )
8482 );
8583 }
8684
8785 private final ManagedCuVSResources [] pool ;
8886 private final int capacity ;
89- private final GPUInfoProvider gpuInfoProvider ;
87+ private final GPUMemoryService gpuMemoryService ;
9088 private int createdCount ;
9189
9290 ReentrantLock lock = new ReentrantLock ();
9391 Condition enoughResourcesCondition = lock .newCondition ();
9492
95- public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
93+ PoolingCuVSResourceManager (int capacity , GPUMemoryService gpuMemoryService ) {
9694 if (capacity < 1 || capacity > MAX_RESOURCES ) {
9795 throw new IllegalArgumentException ("Resource count must be between 1 and " + MAX_RESOURCES );
9896 }
9997 this .capacity = capacity ;
100- this .gpuInfoProvider = gpuInfoProvider ;
98+ this .gpuMemoryService = gpuMemoryService ;
10199 this .pool = new ManagedCuVSResources [MAX_RESOURCES ];
102100 }
103101
104102 private ManagedCuVSResources getResourceFromPool () {
105103 for (int i = 0 ; i < createdCount ; ++i ) {
106104 var res = pool [i ];
107- if (res .locked == false ) {
105+ if (res .isLocked () == false ) {
108106 return res ;
109107 }
110108 }
@@ -120,43 +118,45 @@ private int numLockedResources() {
120118 int lockedResources = 0 ;
121119 for (int i = 0 ; i < createdCount ; ++i ) {
122120 var res = pool [i ];
123- if (res .locked ) {
121+ if (res .isLocked () ) {
124122 lockedResources ++;
125123 }
126124 }
127125 return lockedResources ;
128126 }
129127
130128 @ Override
131- public ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType ) throws InterruptedException {
129+ public ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType , CagraIndexParams cagraIndexParams )
130+ throws InterruptedException {
132131 try {
133132 var started = System .nanoTime ();
134133 lock .lock ();
135134
136135 boolean allConditionsMet = false ;
137136 ManagedCuVSResources res = null ;
137+
138+ long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims , dataType , cagraIndexParams );
139+ logger .debug (
140+ "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]" ,
141+ numVectors ,
142+ dims ,
143+ dataType .name (),
144+ requiredMemoryInBytes
145+ );
146+
138147 while (allConditionsMet == false ) {
139148 res = getResourceFromPool ();
140149
141150 final boolean enoughMemory ;
142151 if (res != null ) {
143- long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims , dataType );
144- logger .debug (
145- "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]" ,
146- numVectors ,
147- dims ,
148- dataType .name (),
149- requiredMemoryInBytes
150- );
151-
152152 // Check immutable constraints
153- long totalDeviceMemoryInBytes = gpuInfoProvider . getCurrentInfo (res ). totalDeviceMemoryInBytes ( );
154- if (requiredMemoryInBytes > totalDeviceMemoryInBytes ) {
153+ long totalMemoryInBytes = gpuMemoryService . totalMemoryInBytes (res );
154+ if (requiredMemoryInBytes > totalMemoryInBytes ) {
155155 String message = Strings .format (
156156 "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]" ,
157157 numVectors ,
158158 dims ,
159- totalDeviceMemoryInBytes
159+ totalMemoryInBytes
160160 );
161161 logger .error (message );
162162 throw new IllegalArgumentException (message );
@@ -169,9 +169,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
169169 }
170170
171171 // Check resources availability
172- long freeDeviceMemoryInBytes = gpuInfoProvider . getCurrentInfo (res ). freeDeviceMemoryInBytes ( );
173- enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes ;
174- logger .debug ("Free device memory [{} B], enoughMemory[{}]" , freeDeviceMemoryInBytes , enoughMemory );
172+ long availableMemoryInBytes = gpuMemoryService . availableMemoryInBytes (res );
173+ enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes ;
174+ logger .debug ("Free device memory [{} B], enoughMemory[{}]" , availableMemoryInBytes , enoughMemory );
175175 } else {
176176 logger .debug ("No resources available in pool" );
177177 enoughMemory = false ;
@@ -184,19 +184,33 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
184184 }
185185 var elapsed = started - System .nanoTime ();
186186 logger .debug ("Resource acquired in [{}ms]" , elapsed / 1_000_000.0 );
187- res .locked = true ;
187+ gpuMemoryService .reserveMemory (requiredMemoryInBytes );
188+ res .lock (() -> gpuMemoryService .releaseMemory (requiredMemoryInBytes ));
188189 return res ;
189190 } finally {
190191 lock .unlock ();
191192 }
192193 }
193194
194- private long estimateRequiredMemory (int numVectors , int dims , CuVSMatrix .DataType dataType ) {
195+ private long estimateRequiredMemory (int numVectors , int dims , CuVSMatrix .DataType dataType , CagraIndexParams cagraIndexParams ) {
195196 int elementTypeBytes = switch (dataType ) {
196197 case FLOAT -> Float .BYTES ;
197198 case INT , UINT -> Integer .BYTES ;
198199 case BYTE -> Byte .BYTES ;
199200 };
201+
202+ if (cagraIndexParams .getCagraGraphBuildAlgo () == CagraIndexParams .CagraGraphBuildAlgo .IVF_PQ
203+ && cagraIndexParams .getCuVSIvfPqParams () != null
204+ && cagraIndexParams .getCuVSIvfPqParams ().getIndexParams () != null
205+ && cagraIndexParams .getCuVSIvfPqParams ().getIndexParams ().getPqDim () != 0 ) {
206+ // See https://docs.rapids.ai/api/cuvs/nightly/neighbors/ivfpq/#index-device-memory
207+ var pqDim = cagraIndexParams .getCuVSIvfPqParams ().getIndexParams ().getPqDim ();
208+ var pqBits = cagraIndexParams .getCuVSIvfPqParams ().getIndexParams ().getPqBits ();
209+ var numClusters = cagraIndexParams .getCuVSIvfPqParams ().getIndexParams ().getnLists ();
210+ var approximatedIvfBytes = numVectors * (pqDim * (pqBits / 8.0 ) + elementTypeBytes ) + (long ) numClusters * Integer .BYTES ;
211+ return (long ) (GPU_COMPUTATION_MEMORY_FACTOR * approximatedIvfBytes );
212+ }
213+
200214 return (long ) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes );
201215 }
202216
@@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) {
217231 logger .debug ("Releasing resources to pool" );
218232 try {
219233 lock .lock ();
220- assert resources .locked ;
221- resources .locked = false ;
234+ assert resources .isLocked () ;
235+ resources .unlock () ;
222236 enoughResourcesCondition .signalAll ();
223237 } finally {
224238 lock .unlock ();
@@ -238,8 +252,9 @@ public void shutdown() {
238252 /** A managed resource. Cannot be closed. */
239253 final class ManagedCuVSResources implements CuVSResources {
240254
241- final CuVSResources delegate ;
242- boolean locked = false ;
255+ private final CuVSResources delegate ;
256+ private static final Runnable NOT_LOCKED = () -> {};
257+ private Runnable unlockAction = NOT_LOCKED ;
243258
244259 ManagedCuVSResources (CuVSResources resources ) {
245260 this .delegate = resources ;
@@ -269,5 +284,18 @@ public Path tempDirectory() {
269284 public String toString () {
270285 return "ManagedCuVSResources[delegate=" + delegate + "]" ;
271286 }
287+
288+ void lock (Runnable unlockAction ) {
289+ this .unlockAction = unlockAction ;
290+ }
291+
292+ void unlock () {
293+ unlockAction .run ();
294+ unlockAction = NOT_LOCKED ;
295+ }
296+
297+ boolean isLocked () {
298+ return unlockAction != NOT_LOCKED ;
299+ }
272300 }
273301}
0 commit comments