2727import org .elasticsearch .xpack .core .ml .action .StartTrainedModelDeploymentAction ;
2828import org .elasticsearch .xpack .core .ml .action .UpdateTrainedModelDeploymentAction ;
2929import org .elasticsearch .xpack .core .ml .inference .assignment .AdaptiveAllocationsSettings ;
30+ import org .elasticsearch .xpack .core .ml .inference .assignment .AssignmentState ;
3031import org .elasticsearch .xpack .core .ml .inference .assignment .AssignmentStats ;
3132import org .elasticsearch .xpack .core .ml .inference .assignment .Priority ;
3233import org .elasticsearch .xpack .core .ml .inference .assignment .TrainedModelAssignment ;
@@ -85,7 +86,7 @@ public void tearDown() throws Exception {
8586 super .tearDown ();
8687 }
8788
88- private ClusterState getClusterState (int numAllocations ) {
89+ private ClusterState getClusterState (int numAllocations , AssignmentState assignmentState ) {
8990 ClusterState clusterState = mock (ClusterState .class );
9091 Metadata metadata = mock (Metadata .class );
9192 when (clusterState .getMetadata ()).thenReturn (metadata );
@@ -107,7 +108,7 @@ private ClusterState getClusterState(int numAllocations) {
107108 100_000_000
108109 ),
109110 new AdaptiveAllocationsSettings (true , null , null )
110- ).build ()
111+ ).setAssignmentState ( assignmentState ). build ()
111112 )
112113 )
113114 );
@@ -118,7 +119,8 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
118119 int numAllocations ,
119120 int inferenceCount ,
120121 double latency ,
121- boolean recentStartup
122+ boolean recentStartup ,
123+ AssignmentState assignmentState
122124 ) {
123125 return new GetDeploymentStatsAction .Response (
124126 List .of (),
@@ -155,15 +157,15 @@ private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
155157 )
156158 ),
157159 Priority .NORMAL
158- )
160+ ). setState ( assignmentState )
159161 ),
160162 0
161163 );
162164 }
163165
164166 public void test_scaleUp () {
165167 // Initialize the cluster with a deployment with 1 allocation.
166- ClusterState clusterState = getClusterState (1 );
168+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
167169 when (clusterService .state ()).thenReturn (clusterState );
168170
169171 AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -189,7 +191,7 @@ public void test_scaleUp() {
189191 doAnswer (invocationOnMock -> {
190192 @ SuppressWarnings ("unchecked" )
191193 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
192- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false ));
194+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false , AssignmentState . STARTED ));
193195 return Void .TYPE ;
194196 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
195197
@@ -205,7 +207,7 @@ public void test_scaleUp() {
205207 doAnswer (invocationOnMock -> {
206208 @ SuppressWarnings ("unchecked" )
207209 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
208- listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 , false ));
210+ listener .onResponse (getDeploymentStatsResponse (1 , 150 , 10.0 , false , AssignmentState . STARTED ));
209211 return Void .TYPE ;
210212 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
211213 doAnswer (invocationOnMock -> {
@@ -226,7 +228,7 @@ public void test_scaleUp() {
226228 verifyNoMoreInteractions (client , clusterService );
227229 reset (client , clusterService );
228230
229- clusterState = getClusterState (2 );
231+ clusterState = getClusterState (2 , AssignmentState . STARTED );
230232 ClusterChangedEvent clusterChangedEvent = mock (ClusterChangedEvent .class );
231233 when (clusterChangedEvent .state ()).thenReturn (clusterState );
232234 service .clusterChanged (clusterChangedEvent );
@@ -236,7 +238,7 @@ public void test_scaleUp() {
236238 doAnswer (invocationOnMock -> {
237239 @ SuppressWarnings ("unchecked" )
238240 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
239- listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false ));
241+ listener .onResponse (getDeploymentStatsResponse (2 , 0 , 9.0 , false , AssignmentState . STARTED ));
240242 return Void .TYPE ;
241243 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
242244 doAnswer (invocationOnMock -> {
@@ -257,7 +259,7 @@ public void test_scaleUp() {
257259
258260 public void test_scaleDownToZero_whenNoRequests () {
259261 // Initialize the cluster with a deployment with 1 allocation.
260- ClusterState clusterState = getClusterState (1 );
262+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
261263 when (clusterService .state ()).thenReturn (clusterState );
262264
263265 AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -283,7 +285,7 @@ public void test_scaleDownToZero_whenNoRequests() {
283285 doAnswer (invocationOnMock -> {
284286 @ SuppressWarnings ("unchecked" )
285287 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
286- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false ));
288+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , false , AssignmentState . STARTED ));
287289 return Void .TYPE ;
288290 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
289291
@@ -299,7 +301,7 @@ public void test_scaleDownToZero_whenNoRequests() {
299301 doAnswer (invocationOnMock -> {
300302 @ SuppressWarnings ("unchecked" )
301303 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
302- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
304+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false , AssignmentState . STARTED ));
303305 return Void .TYPE ;
304306 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
305307 doAnswer (invocationOnMock -> {
@@ -322,9 +324,65 @@ public void test_scaleDownToZero_whenNoRequests() {
322324 service .stop ();
323325 }
324326
327+ public void test_dontScale_whenNotStarted () { // Checks that only deployment stats are fetched (no other client call) while the assignment state is STARTING.
328+ // Initialize the cluster with a deployment with 1 allocation whose assignment state is STARTING (not STARTED).
329+ ClusterState clusterState = getClusterState (1 , AssignmentState .STARTING );
330+ when (clusterService .state ()).thenReturn (clusterState );
331+
332+ AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
333+ threadPool ,
334+ clusterService ,
335+ client ,
336+ inferenceAuditor ,
337+ meterRegistry ,
338+ true , // NOTE(review): the meaning of this flag and of the three numbers below is not visible here - confirm against the constructor
339+ 1 ,
340+ 1 ,
341+ 2_000
342+ );
343+ service .start ();
344+
345+ verify (clusterService ).state (); // start() is expected to read the cluster state once...
346+ verify (clusterService ).addListener (same (service )); // ...and to register the service itself as a cluster listener
347+ verifyNoMoreInteractions (client , clusterService );
348+ reset (client , clusterService ); // clear recorded interactions so each cycle below is verified in isolation
349+
350+ // First cycle: the stats report many inference requests (10000) while the assignment state is still STARTING.
351+ when (client .threadPool ()).thenReturn (threadPool );
352+ doAnswer (invocationOnMock -> {
353+ @ SuppressWarnings ("unchecked" )
354+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ]; // third argument of client.execute(...) is the response listener
355+ listener .onResponse (getDeploymentStatsResponse (1 , 10000 , 10.0 , false , AssignmentState .STARTING )); // 1 allocation, 10000 inferences, latency 10.0, no recent startup
356+ return Void .TYPE ;
357+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
358+
359+ safeSleep (1200 ); // presumably long enough for exactly one scaling cycle to run - TODO confirm against the interval passed to the constructor
360+
361+ verify (client , times (1 )).threadPool ();
362+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ()); // exactly one stats request was sent
363+ verifyNoMoreInteractions (client , clusterService ); // no other client call (e.g. a deployment update) was made despite the load
364+ reset (client , clusterService );
365+
366+ // Second cycle: again many inference requests (20000 reported), assignment state still STARTING.
367+ when (client .threadPool ()).thenReturn (threadPool );
368+ doAnswer (invocationOnMock -> {
369+ @ SuppressWarnings ("unchecked" )
370+ var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
371+ listener .onResponse (getDeploymentStatsResponse (1 , 20000 , 10.0 , false , AssignmentState .STARTING ));
372+ return Void .TYPE ;
373+ }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
374+
375+ safeSleep (1200 ); // presumably one more scaling cycle - TODO confirm
376+
377+ verify (client , times (1 )).threadPool ();
378+ verify (client , times (1 )).execute (eq (GetDeploymentStatsAction .INSTANCE ), any (), any ());
379+ verifyNoMoreInteractions (client , clusterService ); // still only the stats request; nothing else was executed on the client
380+ service .stop ();
381+ }
382+
325383 public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode () {
326384 // Initialize the cluster with a deployment with 1 allocation.
327- ClusterState clusterState = getClusterState (1 );
385+ ClusterState clusterState = getClusterState (1 , AssignmentState . STARTED );
328386 when (clusterService .state ()).thenReturn (clusterState );
329387
330388 AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService (
@@ -350,7 +408,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
350408 doAnswer (invocationOnMock -> {
351409 @ SuppressWarnings ("unchecked" )
352410 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
353- listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true ));
411+ listener .onResponse (getDeploymentStatsResponse (1 , 1 , 11.0 , true , AssignmentState . STARTED ));
354412 return Void .TYPE ;
355413 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
356414
@@ -366,7 +424,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
366424 doAnswer (invocationOnMock -> {
367425 @ SuppressWarnings ("unchecked" )
368426 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
369- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true ));
427+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , true , AssignmentState . STARTED ));
370428 return Void .TYPE ;
371429 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
372430 doAnswer (invocationOnMock -> {
@@ -388,7 +446,7 @@ public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
388446 doAnswer (invocationOnMock -> {
389447 @ SuppressWarnings ("unchecked" )
390448 var listener = (ActionListener <GetDeploymentStatsAction .Response >) invocationOnMock .getArguments ()[2 ];
391- listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false ));
449+ listener .onResponse (getDeploymentStatsResponse (1 , 0 , 10.0 , false , AssignmentState . STARTED ));
392450 return Void .TYPE ;
393451 }).when (client ).execute (eq (GetDeploymentStatsAction .INSTANCE ), eq (new GetDeploymentStatsAction .Request ("test-deployment" )), any ());
394452 doAnswer (invocationOnMock -> {
0 commit comments