77
88package org .elasticsearch .xpack .gpu .codec ;
99
10+ import com .nvidia .cuvs .CuVSMatrix ;
1011import com .nvidia .cuvs .CuVSResources ;
12+ import com .nvidia .cuvs .GPUInfoProvider ;
13+ import com .nvidia .cuvs .spi .CuVSProvider ;
1114
15+ import org .elasticsearch .core .Strings ;
16+ import org .elasticsearch .logging .LogManager ;
17+ import org .elasticsearch .logging .Logger ;
1218import org .elasticsearch .xpack .gpu .GPUSupport ;
1319
1420import java .nio .file .Path ;
1521import java .util .Objects ;
16- import java .util .concurrent .ArrayBlockingQueue ;
17- import java .util .concurrent .BlockingQueue ;
22+ import java .util .concurrent .locks . Condition ;
23+ import java .util .concurrent .locks . ReentrantLock ;
1824
1925/**
2026 * A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
@@ -44,7 +50,7 @@ public interface CuVSResourceManager {
// numVectors, dims and dataType describe the data to be processed. An implementation can use them, together
// with GPU metadata (memory, generation, etc.), to decide whether to hand out a resource or not, e.g. when
// acquiring for 10M x 1536 dims versus 100,000 x 128 dims.
47- ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException ;
53+ ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix . DataType dataType ) throws InterruptedException ;
4854
4955 /** Marks the resources as finished with regard to compute. */
5056 void finishedComputation (ManagedCuVSResources resources );
@@ -65,35 +71,127 @@ static CuVSResourceManager pooling() {
6571 */
6672 class PoolingCuVSResourceManager implements CuVSResourceManager {
6773
74+ static final Logger logger = LogManager .getLogger (CuVSResourceManager .class );
75+
76+ /** A multiplier on input data to account for intermediate and output data size required while processing it */
77+ static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0 ;
6878 static final int MAX_RESOURCES = 2 ;
69- static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (MAX_RESOURCES );
79+ static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (
80+ MAX_RESOURCES ,
81+ CuVSProvider .provider ().gpuInfoProvider ()
82+ );
83+
84+ private final ManagedCuVSResources [] pool ;
85+ private final int capacity ;
86+ private final GPUInfoProvider gpuInfoProvider ;
87+ private int createdCount ;
7088
71- final BlockingQueue <ManagedCuVSResources > pool ;
72- final int capacity ;
73- int createdCount ;
89+ ReentrantLock lock = new ReentrantLock ();
90+ Condition enoughResourcesCondition = lock .newCondition ();
7491
75- public PoolingCuVSResourceManager (int capacity ) {
92+ public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
7693 if (capacity < 1 || capacity > MAX_RESOURCES ) {
7794 throw new IllegalArgumentException ("Resource count must be between 1 and " + MAX_RESOURCES );
7895 }
7996 this .capacity = capacity ;
80- this .pool = new ArrayBlockingQueue <>(capacity );
97+ this .gpuInfoProvider = gpuInfoProvider ;
98+ this .pool = new ManagedCuVSResources [MAX_RESOURCES ];
8199 }
82100
83- @ Override
84- public ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException {
85- ManagedCuVSResources res = pool .poll ();
86- if (res != null ) {
101+ private ManagedCuVSResources getResourceFromPool () {
102+ for (int i = 0 ; i < createdCount ; ++i ) {
103+ var res = pool [i ];
104+ if (res .locked == false ) {
105+ return res ;
106+ }
107+ }
108+ if (createdCount < capacity ) {
109+ var res = new ManagedCuVSResources (Objects .requireNonNull (createNew ()));
110+ pool [createdCount ++] = res ;
87111 return res ;
88112 }
89- synchronized (this ) {
90- if (createdCount < capacity ) {
91- createdCount ++;
92- return new ManagedCuVSResources (Objects .requireNonNull (createNew ()));
113+ return null ;
114+ }
115+
116+ private int numLockedResources () {
117+ int lockedResources = 0 ;
118+ for (int i = 0 ; i < createdCount ; ++i ) {
119+ var res = pool [i ];
120+ if (res .locked ) {
121+ lockedResources ++;
93122 }
94123 }
95- // Otherwise, wait for one to be released
96- return pool .take ();
124+ return lockedResources ;
125+ }
126+
127+ @ Override
128+ public ManagedCuVSResources acquire (int numVectors , int dims , CuVSMatrix .DataType dataType ) throws InterruptedException {
129+ try {
130+ lock .lock ();
131+
132+ boolean allConditionsMet = false ;
133+ ManagedCuVSResources res = null ;
134+ while (allConditionsMet == false ) {
135+ res = getResourceFromPool ();
136+
137+ final boolean enoughMemory ;
138+ if (res != null ) {
139+ long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims , dataType );
140+ logger .info (
141+ "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]" ,
142+ numVectors ,
143+ dims ,
144+ dataType .name (),
145+ requiredMemoryInBytes
146+ );
147+
148+ // Check immutable constraints
149+ long totalDeviceMemoryInBytes = gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes ();
150+ if (requiredMemoryInBytes > totalDeviceMemoryInBytes ) {
151+ String message = Strings .format (
152+ "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]" ,
153+ numVectors ,
154+ dims ,
155+ totalDeviceMemoryInBytes
156+ );
157+ logger .error (message );
158+ throw new IllegalArgumentException (message );
159+ }
160+
161+ // If no resource in the pool is locked, short circuit to avoid livelock
162+ if (numLockedResources () == 0 ) {
163+ logger .info ("No resources currently locked, proceeding" );
164+ break ;
165+ }
166+
167+ // Check resources availability
168+ long freeDeviceMemoryInBytes = gpuInfoProvider .getCurrentInfo (res ).freeDeviceMemoryInBytes ();
169+ enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes ;
170+ logger .info ("Free device memory [{} B], enoughMemory[{}]" , freeDeviceMemoryInBytes );
171+ } else {
172+ logger .info ("No resources available in pool" );
173+ enoughMemory = false ;
174+ }
175+ // TODO: add enoughComputation / enoughComputationCondition here
176+ allConditionsMet = enoughMemory ; // && enoughComputation
177+ if (allConditionsMet == false ) {
178+ enoughResourcesCondition .await ();
179+ }
180+ }
181+ res .locked = true ;
182+ return res ;
183+ } finally {
184+ lock .unlock ();
185+ }
186+ }
187+
188+ private long estimateRequiredMemory (int numVectors , int dims , CuVSMatrix .DataType dataType ) {
189+ int elementTypeBytes = switch (dataType ) {
190+ case FLOAT -> Float .BYTES ;
191+ case INT , UINT -> Integer .BYTES ;
192+ case BYTE -> Byte .BYTES ;
193+ };
194+ return (long ) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes );
97195 }
98196
99197 // visible for testing
@@ -103,28 +201,39 @@ protected CuVSResources createNew() {
103201
104202 @ Override
105203 public void finishedComputation (ManagedCuVSResources resources ) {
204+ logger .info ("Computation finished" );
106205 // currently does nothing, but could allow acquire to return possibly blocked resources
206+ // enoughResourcesCondition.signalAll()
107207 }
108208
109209 @ Override
110210 public void release (ManagedCuVSResources resources ) {
111- var added = pool .offer (Objects .requireNonNull (resources ));
112- assert added : "Failed to release resource back to pool" ;
211+ logger .info ("Releasing resources to pool" );
212+ try {
213+ lock .lock ();
214+ assert resources .locked ;
215+ resources .locked = false ;
216+ enoughResourcesCondition .signalAll ();
217+ } finally {
218+ lock .unlock ();
219+ }
113220 }
114221
115222 @ Override
116223 public void shutdown () {
117- for (ManagedCuVSResources res : pool ) {
224+ for (int i = 0 ; i < createdCount ; ++i ) {
225+ var res = pool [i ];
226+ assert res != null ;
118227 res .delegate .close ();
119228 }
120- pool .clear ();
121229 }
122230 }
123231
124232 /** A managed resource. Cannot be closed. */
125233 final class ManagedCuVSResources implements CuVSResources {
126234
127235 final CuVSResources delegate ;
236+ boolean locked = false ;
128237
129238 ManagedCuVSResources (CuVSResources resources ) {
130239 this .delegate = resources ;
@@ -135,6 +244,11 @@ public ScopedAccess access() {
135244 return delegate .access ();
136245 }
137246
247+ @ Override
248+ public int deviceId () {
249+ return delegate .deviceId ();
250+ }
251+
138252 @ Override
139253 public void close () {
140254 throw new UnsupportedOperationException ("this resource is managed, cannot be closed by clients" );
0 commit comments