3636import  org .junit .After ;
3737import  org .junit .Before ;
3838
39- import  java .io .IOException ;
4039import  java .time .Instant ;
40+ import  java .time .temporal .ChronoUnit ;
4141import  java .util .List ;
4242import  java .util .Map ;
4343import  java .util .Set ;
@@ -114,7 +114,12 @@ private ClusterState getClusterState(int numAllocations) {
114114        return  clusterState ;
115115    }
116116
117-     private  GetDeploymentStatsAction .Response  getDeploymentStatsResponse (int  numAllocations , int  inferenceCount , double  latency ) {
117+     private  GetDeploymentStatsAction .Response  getDeploymentStatsResponse (
118+         int  numAllocations ,
119+         int  inferenceCount ,
120+         double  latency ,
121+         boolean  recentStartup 
122+     ) {
118123        return  new  GetDeploymentStatsAction .Response (
119124            List .of (),
120125            List .of (),
@@ -127,7 +132,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
127132                    new  AdaptiveAllocationsSettings (true , null , null ),
128133                    1024 ,
129134                    ByteSizeValue .ZERO ,
130-                     Instant .now (),
135+                     Instant .now (). minus ( 1 ,  ChronoUnit . DAYS ) ,
131136                    List .of (
132137                        AssignmentStats .NodeStats .forStartedState (
133138                            DiscoveryNodeUtils .create ("node_1" ),
@@ -140,7 +145,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
140145                            0 ,
141146                            0 ,
142147                            Instant .now (),
143-                             Instant .now (),
148+                             recentStartup  ?  Instant .now () :  Instant . now (). minus ( 1 ,  ChronoUnit . HOURS ),
144149                            1 ,
145150                            numAllocations ,
146151                            inferenceCount ,
@@ -156,7 +161,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
156161        );
157162    }
158163
159-     public  void  test ()  throws   IOException  {
164+     public  void  test_scaleUp ()  {
160165        // Initialize the cluster with a deployment with 1 allocation. 
161166        ClusterState  clusterState  = getClusterState (1 );
162167        when (clusterService .state ()).thenReturn (clusterState );
@@ -168,7 +173,9 @@ public void test() throws IOException {
168173            inferenceAuditor ,
169174            meterRegistry ,
170175            true ,
171-             1 
176+             1 ,
177+             60 ,
178+             60_000 
172179        );
173180        service .start ();
174181
@@ -182,7 +189,7 @@ public void test() throws IOException {
182189        doAnswer (invocationOnMock  -> {
183190            @ SuppressWarnings ("unchecked" )
184191            var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
185-             listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 ));
192+             listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 ,  false ));
186193            return  Void .TYPE ;
187194        }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
188195
@@ -198,7 +205,7 @@ public void test() throws IOException {
198205        doAnswer (invocationOnMock  -> {
199206            @ SuppressWarnings ("unchecked" )
200207            var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
201-             listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 ));
208+             listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 ,  false ));
202209            return  Void .TYPE ;
203210        }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
204211        doAnswer (invocationOnMock  -> {
@@ -229,7 +236,137 @@ public void test() throws IOException {
229236        doAnswer (invocationOnMock  -> {
230237            @ SuppressWarnings ("unchecked" )
231238            var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
232-             listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 ));
239+             listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false ));
240+             return  Void .TYPE ;
241+         }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
242+         doAnswer (invocationOnMock  -> {
243+             @ SuppressWarnings ("unchecked" )
244+             var  listener  = (ActionListener <CreateTrainedModelAssignmentAction .Response >) invocationOnMock .getArguments ()[2 ];
245+             listener .onResponse (null );
246+             return  Void .TYPE ;
247+         }).when (client ).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), any (), any ());
248+ 
249+         safeSleep (1000 );
250+ 
251+         verify (client , times (1 )).threadPool ();
252+         verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
253+         verifyNoMoreInteractions (client , clusterService );
254+ 
255+         service .stop ();
256+     }
257+ 
258+     public  void  test_scaleDownToZero_whenNoRequests () { // Checks that a cycle reporting 0 inference requests makes the service request an update to 0 allocations.
259+         // Initialize the mocked cluster state with a deployment that has 1 allocation (built by getClusterState).
260+         ClusterState  clusterState  = getClusterState (1 );
261+         when (clusterService .state ()).thenReturn (clusterState );
262+ 
263+         AdaptiveAllocationsScalerService  service  = new  AdaptiveAllocationsScalerService (
264+             threadPool ,
265+             clusterService ,
266+             client ,
267+             inferenceAuditor ,
268+             meterRegistry ,
269+             true ,
270+             1 , // NOTE(review): presumably the cycle interval in seconds, matching the ~1 s sleeps below — confirm against the constructor
271+             1 , // NOTE(review): test_scaleUp passes 60 here; presumably the idle time in seconds before scaling to zero — confirm
272+             2_000 // NOTE(review): test_scaleUp passes 60_000 here; presumably a millisecond window for "recent" scale-ups — confirm
273+         );
274+         service .start ();
275+ 
276+         verify (clusterService ).state (); // start() is expected to read the cluster state once...
277+         verify (clusterService ).addListener (same (service )); // ...and to register the service itself as a cluster listener
278+         verifyNoMoreInteractions (client , clusterService ); // nothing else may have happened yet; in particular, no client calls
279+         reset (client , clusterService ); // clear recorded interactions and stubbings so the next cycle is verified in isolation
280+ 
281+         // First cycle: 1 inference request, so no need for scaling.
282+         when (client .threadPool ()).thenReturn (threadPool );
283+         doAnswer (invocationOnMock  -> {
284+             @ SuppressWarnings ("unchecked" )
285+             var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ]; // the listener is the third argument of client.execute(...)
286+             listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false )); // numAllocations=1, inferenceCount=1, latency=11.0, recentStartup=false
287+             return  Void .TYPE ;
288+         }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
289+ 
290+         safeSleep (1200 ); // give the scaler time to run a cycle — presumably exactly one fits in this window; confirm
291+ 
292+         verify (client , times (1 )).threadPool ();
293+         verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ()); // exactly one stats request was made
294+         verifyNoMoreInteractions (client , clusterService ); // no update request: the allocation count must stay unchanged in this cycle
295+         reset (client , clusterService );
296+ 
297+         // Second cycle: 0 inference requests for 1 second, so scale down to 0 allocations.
298+         when (client .threadPool ()).thenReturn (threadPool );
299+         doAnswer (invocationOnMock  -> {
300+             @ SuppressWarnings ("unchecked" )
301+             var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
302+             listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false )); // same deployment, but inferenceCount=0 this time
303+             return  Void .TYPE ;
304+         }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
305+         doAnswer (invocationOnMock  -> { // stub the update action so it completes immediately
306+             @ SuppressWarnings ("unchecked" )
307+             var  listener  = (ActionListener <CreateTrainedModelAssignmentAction .Response >) invocationOnMock .getArguments ()[2 ]; // NOTE(review): the generic type of this unchecked cast is erased and only null is passed, so it has no runtime effect
308+             listener .onResponse (null );
309+             return  Void .TYPE ;
310+         }).when (client ).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), any (), any ());
311+ 
312+         safeSleep (1000 ); // wait for the next scaler cycle
313+ 
314+         verify (client , times (2 )).threadPool (); // NOTE(review): presumably once for the stats request and once for the update request — confirm
315+         verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
316+         var  updateRequest  = new  UpdateTrainedModelDeploymentAction .Request ("test-deployment" ); // expected update: same deployment id...
317+         updateRequest .setNumberOfAllocations (0 ); // ...scaled down to zero allocations...
318+         updateRequest .setIsInternal (true ); // ...and flagged as an internal request
319+         verify (client , times (1 )).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), eq (updateRequest ), any ());
320+         verifyNoMoreInteractions (client , clusterService );
321+ 
322+         service .stop ();
323+     }
324+ 
325+     public  void  test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode () {
326+         // Initialize the cluster with a deployment with 1 allocation. 
327+         ClusterState  clusterState  = getClusterState (1 );
328+         when (clusterService .state ()).thenReturn (clusterState );
329+ 
330+         AdaptiveAllocationsScalerService  service  = new  AdaptiveAllocationsScalerService (
331+             threadPool ,
332+             clusterService ,
333+             client ,
334+             inferenceAuditor ,
335+             meterRegistry ,
336+             true ,
337+             1 ,
338+             1 ,
339+             2_000 
340+         );
341+         service .start ();
342+ 
343+         verify (clusterService ).state ();
344+         verify (clusterService ).addListener (same (service ));
345+         verifyNoMoreInteractions (client , clusterService );
346+         reset (client , clusterService );
347+ 
348+         // First cycle: 1 inference request, so no need for scaling. 
349+         when (client .threadPool ()).thenReturn (threadPool );
350+         doAnswer (invocationOnMock  -> {
351+             @ SuppressWarnings ("unchecked" )
352+             var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
353+             listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true ));
354+             return  Void .TYPE ;
355+         }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
356+ 
357+         safeSleep (1200 );
358+ 
359+         verify (client , times (1 )).threadPool ();
360+         verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
361+         verifyNoMoreInteractions (client , clusterService );
362+         reset (client , clusterService );
363+ 
364+         // Second cycle: 0 inference requests for 1 second, but a recent scale up by another node. 
365+         when (client .threadPool ()).thenReturn (threadPool );
366+         doAnswer (invocationOnMock  -> {
367+             @ SuppressWarnings ("unchecked" )
368+             var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
369+             listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true ));
233370            return  Void .TYPE ;
234371        }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
235372        doAnswer (invocationOnMock  -> {
@@ -244,6 +381,32 @@ public void test() throws IOException {
244381        verify (client , times (1 )).threadPool ();
245382        verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
246383        verifyNoMoreInteractions (client , clusterService );
384+         reset (client , clusterService );
385+ 
386+         // Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations. 
387+         when (client .threadPool ()).thenReturn (threadPool );
388+         doAnswer (invocationOnMock  -> {
389+             @ SuppressWarnings ("unchecked" )
390+             var  listener  = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
391+             listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
392+             return  Void .TYPE ;
393+         }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new  GetDeploymentStatsAction .Request ("test-deployment" )), any ());
394+         doAnswer (invocationOnMock  -> {
395+             @ SuppressWarnings ("unchecked" )
396+             var  listener  = (ActionListener <CreateTrainedModelAssignmentAction .Response >) invocationOnMock .getArguments ()[2 ];
397+             listener .onResponse (null );
398+             return  Void .TYPE ;
399+         }).when (client ).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), any (), any ());
400+ 
401+         safeSleep (1000 );
402+ 
403+         verify (client , times (2 )).threadPool ();
404+         verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
405+         var  updateRequest  = new  UpdateTrainedModelDeploymentAction .Request ("test-deployment" );
406+         updateRequest .setNumberOfAllocations (0 );
407+         updateRequest .setIsInternal (true );
408+         verify (client , times (1 )).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), eq (updateRequest ), any ());
409+         verifyNoMoreInteractions (client , clusterService );
247410
248411        service .stop ();
249412    }
@@ -256,7 +419,9 @@ public void testMaybeStartAllocation() {
256419            inferenceAuditor ,
257420            meterRegistry ,
258421            true ,
259-             1 
422+             1 ,
423+             60 ,
424+             60_000 
260425        );
261426
262427        when (client .threadPool ()).thenReturn (threadPool );
@@ -289,7 +454,9 @@ public void testMaybeStartAllocation_BlocksMultipleRequests() throws Exception {
289454            inferenceAuditor ,
290455            meterRegistry ,
291456            true ,
292-             1 
457+             1 ,
458+             60 ,
459+             60_000 
293460        );
294461
295462        var  latch  = new  CountDownLatch (1 );
0 commit comments