99
1010import com .nvidia .cuvs .CuVSResources ;
1111
12+ import com .nvidia .cuvs .GPUInfoProvider ;
13+
14+ import com .nvidia .cuvs .spi .CuVSProvider ;
15+
16+ import org .elasticsearch .core .Strings ;
1217import org .elasticsearch .xpack .gpu .GPUSupport ;
1318
1419import java .nio .file .Path ;
1520import java .util .Objects ;
16- import java .util .concurrent .ArrayBlockingQueue ;
17- import java .util .concurrent .BlockingQueue ;
21+ import java .util .concurrent .locks . Condition ;
22+ import java .util .concurrent .locks . ReentrantLock ;
1823
1924/**
2025 * A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
@@ -65,35 +70,84 @@ static CuVSResourceManager pooling() {
6570 */
6671 class PoolingCuVSResourceManager implements CuVSResourceManager {
6772
73+ /** A multiplier on input data to account for intermediate and output data size required while processing it */
74+ static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0 ;
6875 static final int MAX_RESOURCES = 2 ;
69- static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (MAX_RESOURCES );
76+ static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (
77+ MAX_RESOURCES ,
78+ CuVSProvider .provider ().gpuInfoProvider ()
79+ );
7080
71- final BlockingQueue < ManagedCuVSResources > pool ;
81+ final ManagedCuVSResources [] pool ;
7282 final int capacity ;
83+ final GPUInfoProvider gpuInfoProvider ;
7384 int createdCount ;
7485
75- public PoolingCuVSResourceManager (int capacity ) {
86+ ReentrantLock lock = new ReentrantLock ();
87+ Condition poolAvailableCondition = lock .newCondition ();
88+ Condition enoughMemoryCondition = lock .newCondition ();
89+
90+ public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
7691 if (capacity < 1 || capacity > MAX_RESOURCES ) {
7792 throw new IllegalArgumentException ("Resource count must be between 1 and " + MAX_RESOURCES );
7893 }
7994 this .capacity = capacity ;
80- this .pool = new ArrayBlockingQueue <>(capacity );
95+ this .gpuInfoProvider = gpuInfoProvider ;
96+ this .pool = new ManagedCuVSResources [MAX_RESOURCES ];
8197 }
8298
83- @ Override
84- public ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException {
85- ManagedCuVSResources res = pool .poll ();
86- if (res != null ) {
99+ private ManagedCuVSResources getResourceFromPool () {
100+ for (int i = 0 ; i < createdCount ; ++i ) {
101+ var res = pool [i ];
102+ if (res .locked == false ) {
103+ return res ;
104+ }
105+ }
106+ if (createdCount < capacity ) {
107+ var res = new ManagedCuVSResources (Objects .requireNonNull (createNew ()));
108+ pool [createdCount ++] = res ;
87109 return res ;
88110 }
89- synchronized (this ) {
90- if (createdCount < capacity ) {
91- createdCount ++;
92- return new ManagedCuVSResources (Objects .requireNonNull (createNew ()));
111+ return null ;
112+ }
113+
114+ @ Override
115+ public ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException {
116+ try {
117+ lock .lock ();
118+ ManagedCuVSResources res ;
119+ while ((res = getResourceFromPool ()) == null ) {
120+ poolAvailableCondition .await ();
121+ }
122+
123+ // Check resources availability
124+ var resourcesInfo = gpuInfoProvider .getCurrentInfo (res );
125+
126+ // Memory
127+ long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims );
128+ if (requiredMemoryInBytes > resourcesInfo .totalDeviceMemoryInBytes ()) {
129+ throw new IllegalArgumentException (
130+ Strings .format (
131+ "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]" ,
132+ numVectors ,
133+ dims ,
134+ resourcesInfo .totalDeviceMemoryInBytes () / 1048576.0f
135+ )
136+ );
137+ }
138+ while (requiredMemoryInBytes > resourcesInfo .freeDeviceMemoryInBytes ()) {
139+ enoughMemoryCondition .await ();
93140 }
141+
142+ res .locked = true ;
143+ return res ;
144+ } finally {
145+ lock .unlock ();
94146 }
95- // Otherwise, wait for one to be released
96- return pool .take ();
147+ }
148+
149+ private long estimateRequiredMemory (int numVectors , int dims ) {
150+ return (long )(GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * Float .BYTES );
97151 }
98152
99153 // visible for testing
@@ -104,27 +158,37 @@ protected CuVSResources createNew() {
/**
 * Hook invoked with the resources a computation has finished using. The body is intentionally
 * empty: nothing is signalled and no state changes here.
 *
 * @param resources the resources the finished computation was using
 */
@Override
public void finishedComputation(ManagedCuVSResources resources) {
    // Currently does nothing, but could allow acquire to return possibly blocked resources,
    // e.g. by signalling a dedicated condition (something like enoughComputationCondition.signal()?).
    // NOTE(review): no condition with that name is declared in the visible code; it would have to be added.
}
108163
109164 @ Override
110165 public void release (ManagedCuVSResources resources ) {
111- var added = pool .offer (Objects .requireNonNull (resources ));
112- assert added : "Failed to release resource back to pool" ;
166+ try {
167+ lock .lock ();
168+ assert resources .locked ;
169+ resources .locked = false ;
170+ poolAvailableCondition .signalAll ();
171+ enoughMemoryCondition .signalAll ();
172+ } finally {
173+ lock .unlock ();
174+ }
113175 }
114176
115177 @ Override
116178 public void shutdown () {
117- for (ManagedCuVSResources res : pool ) {
179+ for (int i = 0 ; i < createdCount ; ++i ) {
180+ var res = pool [i ];
181+ assert res != null ;
118182 res .delegate .close ();
119183 }
120- pool .clear ();
121184 }
122185 }
123186
124187 /** A managed resource. Cannot be closed. */
125188 final class ManagedCuVSResources implements CuVSResources {
126189
127190 final CuVSResources delegate ;
191+ boolean locked = false ;
128192
129193 ManagedCuVSResources (CuVSResources resources ) {
130194 this .delegate = resources ;
0 commit comments