1616import org .elasticsearch .core .Strings ;
1717import org .elasticsearch .xpack .gpu .GPUSupport ;
1818
19+ import java .lang .foreign .Arena ;
1920import java .nio .file .Path ;
2021import java .util .Objects ;
2122import java .util .concurrent .locks .Condition ;
@@ -65,34 +66,55 @@ static CuVSResourceManager pooling() {
6566 return PoolingCuVSResourceManager .INSTANCE ;
6667 }
6768
69+ @ FunctionalInterface
70+ interface GpuInfoFunction {
71+ long get (CuVSResources resources );
72+ }
73+
6874 /**
6975 * A manager that maintains a pool of resources.
7076 */
7177 class PoolingCuVSResourceManager implements CuVSResourceManager {
7278
7379 /** A multiplier on input data to account for intermediate and output data size required while processing it */
7480 static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0 ;
81+ static final int GPU_UTILIZATION_MAX_PERCENT = 80 ;
7582 static final int MAX_RESOURCES = 2 ;
83+ static final GPUInfoProvider gpuInfoProvider = CuVSProvider .provider ().gpuInfoProvider ();
7684 static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager (
7785 MAX_RESOURCES ,
78- CuVSProvider .provider ().gpuInfoProvider ()
86+ res ->gpuInfoProvider .getCurrentInfo (res ).totalDeviceMemoryInBytes (),
87+ res ->gpuInfoProvider .getCurrentInfo (res ).freeDeviceMemoryInBytes (),
88+ PoolingCuVSResourceManager ::getGpuUtilizationPercent
7989 );
8090
8191 private final ManagedCuVSResources [] pool ;
8292 private final int capacity ;
83- private final GPUInfoProvider gpuInfoProvider ;
8493 private int createdCount ;
8594
86- ReentrantLock lock = new ReentrantLock ();
87- Condition enoughMemoryCondition = lock .newCondition ();
95+ private final GpuInfoFunction totalMemoryInBytesProvider ;
96+ private final GpuInfoFunction freeMemoryInBytesProvider ;
97+ private final GpuInfoFunction gpuUtilizationPercentProvider ;
8898
89- public PoolingCuVSResourceManager (int capacity , GPUInfoProvider gpuInfoProvider ) {
99+ ReentrantLock lock = new ReentrantLock ();
100+ Condition enoughResourcesCondition = lock .newCondition ();
101+
102+ public PoolingCuVSResourceManager (
103+ int capacity ,
104+ GpuInfoFunction totalMemoryInBytesProvider ,
105+ GpuInfoFunction freeMemoryInBytesProvider ,
106+ GpuInfoFunction gpuUtilizationPercentProvider
107+ ) {
108+ this .totalMemoryInBytesProvider = totalMemoryInBytesProvider ;
109+ this .freeMemoryInBytesProvider = freeMemoryInBytesProvider ;
110+ this .gpuUtilizationPercentProvider = gpuUtilizationPercentProvider ;
90111 if (capacity < 1 || capacity > MAX_RESOURCES ) {
91112 throw new IllegalArgumentException ("Resource count must be between 1 and " + MAX_RESOURCES );
92113 }
93114 this .capacity = capacity ;
94- this .gpuInfoProvider = gpuInfoProvider ;
95115 this .pool = new ManagedCuVSResources [MAX_RESOURCES ];
116+
117+ NVML .nvmlInit_v2 ();
96118 }
97119
98120 private ManagedCuVSResources getResourceFromPool () {
@@ -130,35 +152,38 @@ public ManagedCuVSResources acquire(int numVectors, int dims) throws Interrupted
130152 ManagedCuVSResources res = null ;
131153 while (allConditionsMet == false ) {
132154 res = getResourceFromPool ();
133- // If no resource in the pool is locked, short circuit to avoid livelock
134- if (numLockedResources () == 0 ) {
135- break ;
136- }
155+
137156 final boolean enoughMemory ;
157+ final boolean enoughComputation ;
138158 if (res != null ) {
159+ // If no resource in the pool is locked, short circuit to avoid livelock
160+ if (numLockedResources () == 0 ) {
161+ break ;
162+ }
163+
139164 // Check resources availability
140- // Memory
141165 long requiredMemoryInBytes = estimateRequiredMemory (numVectors , dims );
142- if (requiredMemoryInBytes > gpuInfoProvider . getCurrentInfo (res ). totalDeviceMemoryInBytes ( )) {
166+ if (requiredMemoryInBytes > totalMemoryInBytesProvider . get (res )) {
143167 throw new IllegalArgumentException (
144168 Strings .format (
145169 "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]" ,
146170 numVectors ,
147171 dims ,
148- gpuInfoProvider . getCurrentInfo (res ). totalDeviceMemoryInBytes ( ) / (1024L * 1024L )
172+ totalMemoryInBytesProvider . get (res ) / (1024L * 1024L )
149173 )
150174 );
151175 }
152- enoughMemory = requiredMemoryInBytes <= gpuInfoProvider .getCurrentInfo (res ).freeDeviceMemoryInBytes ();
176+ enoughMemory = requiredMemoryInBytes <= freeMemoryInBytesProvider .get (res );
177+ enoughComputation = gpuUtilizationPercentProvider .get (res ) < GPU_UTILIZATION_MAX_PERCENT ;
153178 } else {
154179 enoughMemory = false ;
155- }
156- if (enoughMemory == false ) {
157- enoughMemoryCondition .await ();
180+ enoughComputation = false ;
158181 }
159182
160- // TODO: add enoughComputation / enoughComputationCondition here
161- allConditionsMet = enoughMemory ; // && enoughComputation
183+ allConditionsMet = enoughMemory && enoughComputation ;
184+ if (allConditionsMet == false ) {
185+ enoughResourcesCondition .await ();
186+ }
162187 }
163188 res .locked = true ;
164189 return res ;
@@ -171,15 +196,24 @@ private long estimateRequiredMemory(int numVectors, int dims) {
171196 return (long )(GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * Float .BYTES );
172197 }
173198
199+ private static int getGpuUtilizationPercent (CuVSResources resources ) {
200+ try (var localArena = Arena .ofConfined ()) {
201+ var deviceHandle = NVML .nvmlDeviceGetHandleByIndex_v2 (resources .deviceId ());
202+ var nvmlUtilizationPtr = localArena .allocate (NVML .nvmlUtilization_t .layout ());
203+ NVML .nvmlDeviceGetUtilizationRates (deviceHandle , nvmlUtilizationPtr );
204+ return NVML .nvmlUtilization_t .gpu (nvmlUtilizationPtr );
205+ }
206+ }
207+
174208 // visible for testing
175209 protected CuVSResources createNew () {
176210 return GPUSupport .cuVSResourcesOrNull (true );
177211 }
178212
179213 @ Override
180214 public void finishedComputation (ManagedCuVSResources resources ) {
181- // currently does nothing, but could allow acquire to return possibly blocked resources
182- // something like enoughComputationCondition .signalAll()?
215+ // Allow acquire to return possibly blocked resources
216+ enoughResourcesCondition .signalAll ();
183217 }
184218
185219 @ Override
@@ -188,7 +222,7 @@ public void release(ManagedCuVSResources resources) {
188222 lock .lock ();
189223 assert resources .locked ;
190224 resources .locked = false ;
191- enoughMemoryCondition .signalAll ();
225+ enoughResourcesCondition .signalAll ();
192226 } finally {
193227 lock .unlock ();
194228 }
@@ -201,6 +235,7 @@ public void shutdown() {
201235 assert res != null ;
202236 res .delegate .close ();
203237 }
238+ NVML .nvmlShutdown ();
204239 }
205240 }
206241
0 commit comments