3636import org .junit .After ;
3737import org .junit .Before ;
3838
39- import java .io .IOException ;
4039import java .time .Instant ;
40+ import java .time .temporal .ChronoUnit ;
4141import java .util .List ;
4242import java .util .Map ;
4343import java .util .Set ;
@@ -114,7 +114,12 @@ private ClusterState getClusterState(int numAllocations) {
114114 return clusterState ;
115115 }
116116
117- private GetDeploymentStatsAction .Response getDeploymentStatsResponse (int numAllocations , int inferenceCount , double latency ) {
117+ private GetDeploymentStatsAction .Response getDeploymentStatsResponse (
118+ int numAllocations ,
119+ int inferenceCount ,
120+ double latency ,
121+ boolean recentStartup
122+ ) {
118123 return new GetDeploymentStatsAction .Response (
119124 List .of (),
120125 List .of (),
@@ -127,7 +132,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
127132 new AdaptiveAllocationsSettings (true , null , null ),
128133 1024 ,
129134 ByteSizeValue .ZERO ,
130- Instant .now (),
135+ Instant .now (). minus ( 1 , ChronoUnit . DAYS ) ,
131136 List .of (
132137 AssignmentStats .NodeStats .forStartedState (
133138 randomBoolean () ? DiscoveryNodeUtils .create ("node_1" ) : null ,
@@ -140,7 +145,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
140145 0 ,
141146 0 ,
142147 Instant .now (),
143- Instant .now (),
148+ recentStartup ? Instant .now () : Instant . now (). minus ( 1 , ChronoUnit . HOURS ),
144149 1 ,
145150 numAllocations ,
146151 inferenceCount ,
@@ -156,7 +161,7 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllo
156161 );
157162 }
158163
159- public void test () throws IOException {
164+ public void test_scaleUp () {
160165 // Initialize the cluster with a deployment with 1 allocation.
161166 ClusterState clusterState = getClusterState (1 );
162167 when (clusterService .state ()).thenReturn (clusterState );
@@ -168,7 +173,9 @@ public void test() throws IOException {
168173 inferenceAuditor ,
169174 meterRegistry ,
170175 true ,
171- 1
176+ 1 ,
177+ 60 ,
178+ 60_000
172179 );
173180 service .start ();
174181
@@ -182,7 +189,7 @@ public void test() throws IOException {
182189 doAnswer (invocationOnMock -> {
183190 @ SuppressWarnings ("unchecked" )
184191 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
185- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 ));
192+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false ));
186193 return Void .TYPE ;
187194 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
188195
@@ -198,7 +205,7 @@ public void test() throws IOException {
198205 doAnswer (invocationOnMock -> {
199206 @ SuppressWarnings ("unchecked" )
200207 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
201- listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 ));
208+ listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 , false ));
202209 return Void .TYPE ;
203210 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
204211 doAnswer (invocationOnMock -> {
@@ -229,7 +236,137 @@ public void test() throws IOException {
229236 doAnswer (invocationOnMock -> {
230237 @ SuppressWarnings ("unchecked" )
231238 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
232- listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 ));
239+ listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false ));
240+ return Void .TYPE ;
241+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
242+ doAnswer (invocationOnMock -> {
243+ @ SuppressWarnings ("unchecked" )
244+ var listener = (ActionListener <CreateTrainedModelAssignmentAction .Response >) invocationOnMock .getArguments ()[2 ];
245+ listener .onResponse (null );
246+ return Void .TYPE ;
247+ }).when (client ).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), any (), any ());
248+
249+ safeSleep (1000 );
250+
251+ verify (client , times (1 )).threadPool ();
252+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
253+ verifyNoMoreInteractions (client , clusterService );
254+
255+ service .stop ();
256+ }
257+
/**
 * Verifies that a deployment is scaled down to zero allocations once it has stopped receiving
 * inference requests.
 *
 * The scenario runs two scaler cycles against mocked collaborators: a first cycle reporting one
 * inference request (no update expected), then a cycle reporting zero requests, after which an
 * internal update to 0 allocations is expected.
 */
public void test_scaleDownToZero_whenNoRequests() {
    // Start from a cluster state holding a deployment with a single allocation.
    ClusterState initialState = getClusterState(1);
    when(clusterService.state()).thenReturn(initialState);

    // NOTE(review): the trailing numeric arguments (1, 1, 2_000) are presumably the cycle period
    // and the scale-to-zero timing thresholds; confirm against the service constructor.
    AdaptiveAllocationsScalerService scalerService = new AdaptiveAllocationsScalerService(
        threadPool,
        clusterService,
        client,
        inferenceAuditor,
        meterRegistry,
        true,
        1,
        1,
        2_000
    );
    scalerService.start();

    // Starting the service should only read the cluster state and register the listener.
    verify(clusterService).state();
    verify(clusterService).addListener(same(scalerService));
    verifyNoMoreInteractions(client, clusterService);
    reset(client, clusterService);

    // The stats request the service is expected to issue in every cycle.
    var statsRequest = new GetDeploymentStatsAction.Request("test-deployment");

    // Cycle one: a single inference request is reported, so nothing should be rescaled.
    when(client.threadPool()).thenReturn(threadPool);
    doAnswer(invocation -> {
        @SuppressWarnings("unchecked")
        var statsListener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
        statsListener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
        return Void.TYPE;
    }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(statsRequest), any());

    safeSleep(1200);

    // Only the stats lookup may have happened; no deployment update yet.
    verify(client, times(1)).threadPool();
    verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
    verifyNoMoreInteractions(client, clusterService);
    reset(client, clusterService);

    // Cycle two: zero inference requests for 1 second, so the deployment should drop to 0 allocations.
    when(client.threadPool()).thenReturn(threadPool);
    doAnswer(invocation -> {
        @SuppressWarnings("unchecked")
        var statsListener = (ActionListener<GetDeploymentStatsAction.Response>) invocation.getArguments()[2];
        statsListener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
        return Void.TYPE;
    }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(statsRequest), any());
    doAnswer(invocation -> {
        @SuppressWarnings("unchecked")
        var updateListener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocation.getArguments()[2];
        updateListener.onResponse(null);
        return Void.TYPE;
    }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());

    safeSleep(1000);

    // Expect one stats lookup followed by exactly one internal update setting allocations to 0.
    verify(client, times(2)).threadPool();
    verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
    var expectedUpdate = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
    expectedUpdate.setNumberOfAllocations(0);
    expectedUpdate.setIsInternal(true);
    verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(expectedUpdate), any());
    verifyNoMoreInteractions(client, clusterService);

    scalerService.stop();
}
324+
325+ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode () {
326+ // Initialize the cluster with a deployment with 1 allocation.
327+ ClusterState clusterState = getClusterState (1 );
328+ when (clusterService .state ()).thenReturn (clusterState );
329+
330+ AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
331+ threadPool ,
332+ clusterService ,
333+ client ,
334+ inferenceAuditor ,
335+ meterRegistry ,
336+ true ,
337+ 1 ,
338+ 1 ,
339+ 2_000
340+ );
341+ service .start ();
342+
343+ verify (clusterService ).state ();
344+ verify (clusterService ).addListener (same (service ));
345+ verifyNoMoreInteractions (client , clusterService );
346+ reset (client , clusterService );
347+
348+ // First cycle: 1 inference request, so no need for scaling.
349+ when (client .threadPool ()).thenReturn (threadPool );
350+ doAnswer (invocationOnMock -> {
351+ @ SuppressWarnings ("unchecked" )
352+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
353+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true ));
354+ return Void .TYPE ;
355+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
356+
357+ safeSleep (1200 );
358+
359+ verify (client , times (1 )).threadPool ();
360+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
361+ verifyNoMoreInteractions (client , clusterService );
362+ reset (client , clusterService );
363+
364+ // Second cycle: 0 inference requests for 1 second, but a recent scale up by another node.
365+ when (client .threadPool ()).thenReturn (threadPool );
366+ doAnswer (invocationOnMock -> {
367+ @ SuppressWarnings ("unchecked" )
368+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
369+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true ));
233370 return Void .TYPE ;
234371 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
235372 doAnswer (invocationOnMock -> {
@@ -244,6 +381,32 @@ public void test() throws IOException {
244381 verify (client , times (1 )).threadPool ();
245382 verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
246383 verifyNoMoreInteractions (client , clusterService );
384+ reset (client , clusterService );
385+
386+ // Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations.
387+ when (client .threadPool ()).thenReturn (threadPool );
388+ doAnswer (invocationOnMock -> {
389+ @ SuppressWarnings ("unchecked" )
390+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
391+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
392+ return Void .TYPE ;
393+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
394+ doAnswer (invocationOnMock -> {
395+ @ SuppressWarnings ("unchecked" )
396+ var listener = (ActionListener <CreateTrainedModelAssignmentAction .Response >) invocationOnMock .getArguments ()[2 ];
397+ listener .onResponse (null );
398+ return Void .TYPE ;
399+ }).when (client ).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), any (), any ());
400+
401+ safeSleep (1000 );
402+
403+ verify (client , times (2 )).threadPool ();
404+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
405+ var updateRequest = new UpdateTrainedModelDeploymentAction .Request ("test-deployment" );
406+ updateRequest .setNumberOfAllocations (0 );
407+ updateRequest .setIsInternal (true );
408+ verify (client , times (1 )).execute (eq (UpdateTrainedModelDeploymentAction .INSTANCE ), eq (updateRequest ), any ());
409+ verifyNoMoreInteractions (client , clusterService );
247410
248411 service .stop ();
249412 }
@@ -256,7 +419,9 @@ public void testMaybeStartAllocation() {
256419 inferenceAuditor ,
257420 meterRegistry ,
258421 true ,
259- 1
422+ 1 ,
423+ 60 ,
424+ 60_000
260425 );
261426
262427 when (client .threadPool ()).thenReturn (threadPool );
@@ -289,7 +454,9 @@ public void testMaybeStartAllocation_BlocksMultipleRequests() throws Exception {
289454 inferenceAuditor ,
290455 meterRegistry ,
291456 true ,
292- 1
457+ 1 ,
458+ 60 ,
459+ 60_000
293460 );
294461
295462 var latch = new CountDownLatch (1 );
0 commit comments