1313import com .nvidia .cuvs .GPUInfo ;
1414import com .nvidia .cuvs .GPUInfoProvider ;
1515
16+ import org .elasticsearch .logging .LogManager ;
17+ import org .elasticsearch .logging .Logger ;
1618import org .elasticsearch .test .ESTestCase ;
1719
1820import java .nio .file .Path ;
21+ import java .util .ArrayList ;
1922import java .util .List ;
2023import java .util .concurrent .atomic .AtomicInteger ;
2124import java .util .concurrent .atomic .AtomicReference ;
25+ import java .util .function .LongSupplier ;
2226
2327import static org .hamcrest .Matchers .anyOf ;
2428import static org .hamcrest .Matchers .containsString ;
29+ import static org .hamcrest .Matchers .equalTo ;
30+ import static org .hamcrest .Matchers .not ;
2531
2632public class CuVSResourceManagerTests extends ESTestCase {
2733
34+ private static final Logger log = LogManager .getLogger (CuVSResourceManagerTests .class );
35+
36+ public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024 ;
37+
2838 public void testBasic () throws InterruptedException {
2939 var mgr = new MockPoolingCuVSResourceManager (2 );
3040 var res1 = mgr .acquire (0 , 0 );
@@ -65,10 +75,52 @@ public void testBlocking() throws Exception {
6575 mgr .shutdown ();
6676 }
6777
78+ public void testBlockingOnInsufficientMemory () throws Exception {
79+ var mgr = new MockPoolingCuVSResourceManager (2 );
80+ var res1 = mgr .acquire (16 * 1024 , 1024 );
81+
82+ AtomicReference <CuVSResources > holder = new AtomicReference <>();
83+ Thread t = new Thread (() -> {
84+ try {
85+ var res2 = mgr .acquire ((16 * 1024 ) + 1 , 1024 );
86+ holder .set (res2 );
87+ } catch (InterruptedException e ) {
88+ throw new AssertionError (e );
89+ }
90+ });
91+ t .start ();
92+ Thread .sleep (1_000 );
93+ assertNull (holder .get ());
94+ mgr .release (res1 );
95+ t .join ();
96+ assertThat (holder .get ().toString (), anyOf (containsString ("id=0" ), containsString ("id=1" )));
97+ mgr .shutdown ();
98+ }
99+
100+ public void testNotBlockingOnSufficientMemory () throws Exception {
101+ var mgr = new MockPoolingCuVSResourceManager (2 );
102+ var res1 = mgr .acquire (16 * 1024 , 1024 );
103+
104+ AtomicReference <CuVSResources > holder = new AtomicReference <>();
105+ Thread t = new Thread (() -> {
106+ try {
107+ var res2 = mgr .acquire ((16 * 1024 ) - 1 , 1024 );
108+ holder .set (res2 );
109+ } catch (InterruptedException e ) {
110+ throw new AssertionError (e );
111+ }
112+ });
113+ t .start ();
114+ t .join (5_000 );
115+ assertNotNull (holder .get ());
116+ assertThat (holder .get ().toString (), not (equalTo (res1 .toString ())));
117+ mgr .shutdown ();
118+ }
119+
68120 public void testManagedResIsNotClosable () throws Exception {
69121 var mgr = new MockPoolingCuVSResourceManager (1 );
70122 var res = mgr .acquire (0 , 0 );
71- assertThrows (UnsupportedOperationException .class , () -> res . close () );
123+ assertThrows (UnsupportedOperationException .class , res :: close );
72124 mgr .release (res );
73125 mgr .shutdown ();
74126 }
@@ -85,16 +137,45 @@ public void testDoubleRelease() throws InterruptedException {
85137
86138 static class MockPoolingCuVSResourceManager extends CuVSResourceManager .PoolingCuVSResourceManager {
87139
88- final AtomicInteger idGenerator = new AtomicInteger ();
140+ private final AtomicInteger idGenerator = new AtomicInteger ();
141+ private final List <Long > allocations ;
89142
90143 MockPoolingCuVSResourceManager (int capacity ) {
91- super (capacity , new MockGPUInfoProvider ());
144+ this (capacity , new ArrayList <>());
145+ }
146+
147+ private MockPoolingCuVSResourceManager (int capacity , List <Long > allocationList ) {
148+ super (capacity , new MockGPUInfoProvider (() -> freeMemoryFunction (allocationList )));
149+ this .allocations = allocationList ;
150+ }
151+
152+ private static long freeMemoryFunction (List <Long > allocations ) {
153+ return TOTAL_DEVICE_MEMORY_IN_BYTES - allocations .stream ().mapToLong (x -> x ).sum ();
92154 }
93155
94156 @ Override
95157 protected CuVSResources createNew () {
96158 return new MockCuVSResources (idGenerator .getAndIncrement ());
97159 }
160+
161+ @ Override
162+ public ManagedCuVSResources acquire (int numVectors , int dims ) throws InterruptedException {
163+ var res = super .acquire (numVectors , dims );
164+ long memory = (long )(numVectors * dims * Float .BYTES *
165+ CuVSResourceManager .PoolingCuVSResourceManager .GPU_COMPUTATION_MEMORY_FACTOR );
166+ allocations .add (memory );
167+ log .info ("Added [{}]" , memory );
168+ return res ;
169+ }
170+
171+ @ Override
172+ public void release (ManagedCuVSResources resources ) {
173+ if (allocations .isEmpty () == false ) {
174+ var x = allocations .removeLast ();
175+ log .info ("Removed [{}]" , x );
176+ }
177+ super .release (resources );
178+ }
98179 }
99180
100181 static class MockCuVSResources implements CuVSResources {
@@ -110,6 +191,11 @@ public ScopedAccess access() {
110191 throw new UnsupportedOperationException ();
111192 }
112193
194+ @ Override
195+ public int deviceId () {
196+ return 0 ;
197+ }
198+
113199 @ Override
114200 public void close () {}
115201
@@ -125,6 +211,12 @@ public String toString() {
125211 }
126212
127213 private static class MockGPUInfoProvider implements GPUInfoProvider {
214+ private final LongSupplier freeMemorySupplier ;
215+
216+ MockGPUInfoProvider (LongSupplier freeMemorySupplier ) {
217+ this .freeMemorySupplier = freeMemorySupplier ;
218+ }
219+
128220 @ Override
129221 public List <GPUInfo > availableGPUs () throws Throwable {
130222 throw new UnsupportedOperationException ();
@@ -137,7 +229,7 @@ public List<GPUInfo> compatibleGPUs() throws Throwable {
137229
138230 @ Override
139231 public CuVSResourcesInfo getCurrentInfo (CuVSResources cuVSResources ) {
140- return new CuVSResourcesInfo (256L * 1024 * 1024 , 2048L * 1024 * 1024 );
232+ return new CuVSResourcesInfo (freeMemorySupplier . getAsLong (), TOTAL_DEVICE_MEMORY_IN_BYTES );
141233 }
142234 }
143235}
0 commit comments