1212import org .elasticsearch .ResourceNotFoundException ;
1313import org .elasticsearch .action .ActionListener ;
1414import org .elasticsearch .action .search .SearchPhaseExecutionException ;
15- import org .elasticsearch .action .support .PlainActionFuture ;
16- import org .elasticsearch .action .support .UnsafePlainActionFuture ;
15+ import org .elasticsearch .action .support .SubscribableListener ;
1716import org .elasticsearch .action .support .master .AcknowledgedResponse ;
1817import org .elasticsearch .cluster .ClusterChangedEvent ;
1918import org .elasticsearch .cluster .ClusterState ;
5352import org .elasticsearch .xpack .ml .inference .deployment .TrainedModelDeploymentTask ;
5453import org .elasticsearch .xpack .ml .task .AbstractJobPersistentTasksExecutor ;
5554
56- import java .util .ArrayDeque ;
5755import java .util .ArrayList ;
5856import java .util .Collections ;
5957import java .util .Deque ;
@@ -154,26 +152,38 @@ public void beforeStop() {
154152 this .expressionResolver = expressionResolver ;
155153 }
156154
157- public void start () {
155+ void start () {
158156 stopped = false ;
159- scheduledFuture = threadPool .scheduleWithFixedDelay (
160- this ::loadQueuedModels ,
161- MODEL_LOADING_CHECK_INTERVAL ,
162- threadPool .executor (MachineLearning .UTILITY_THREAD_POOL_NAME )
163- );
157+ schedule (false );
164158 }
165159
166- public void stop () {
160+ private void schedule (boolean runImmediately ) {
161+ if (stopped ) {
162+ // do not schedule when stopped
163+ return ;
164+ }
165+
166+ var rescheduleListener = ActionListener .wrap (this ::schedule , e -> this .schedule (false ));
167+ Runnable loadQueuedModels = () -> loadQueuedModels (rescheduleListener );
168+ var executor = threadPool .executor (MachineLearning .UTILITY_THREAD_POOL_NAME );
169+
170+ if (runImmediately ) {
171+ executor .execute (loadQueuedModels );
172+ } else {
173+ scheduledFuture = threadPool .schedule (loadQueuedModels , MODEL_LOADING_CHECK_INTERVAL , executor );
174+ }
175+ }
176+
177+ void stop () {
167178 stopped = true ;
168179 ThreadPool .Cancellable cancellable = this .scheduledFuture ;
169180 if (cancellable != null ) {
170181 cancellable .cancel ();
171182 }
172183 }
173184
174- void loadQueuedModels () {
175- TrainedModelDeploymentTask loadingTask ;
176- if (loadingModels .isEmpty ()) {
185+ void loadQueuedModels (ActionListener <Boolean > rescheduleImmediately ) {
186+ if (stopped ) {
177187 return ;
178188 }
179189 if (latestState != null ) {
@@ -188,39 +198,49 @@ void loadQueuedModels() {
188198 );
189199 if (unassignedIndices .size () > 0 ) {
190200 logger .trace ("not loading models as indices {} primary shards are unassigned" , unassignedIndices );
201+ rescheduleImmediately .onResponse (false );
191202 return ;
192203 }
193204 }
194- logger .trace ("attempting to load all currently queued models" );
195- // NOTE: As soon as this method exits, the timer for the scheduler starts ticking
196- Deque <TrainedModelDeploymentTask > loadingToRetry = new ArrayDeque <>();
197- while ((loadingTask = loadingModels .poll ()) != null ) {
198- final String deploymentId = loadingTask .getDeploymentId ();
199- if (loadingTask .isStopped ()) {
200- if (logger .isTraceEnabled ()) {
201- String reason = loadingTask .stoppedReason ().orElse ("_unknown_" );
202- logger .trace ("[{}] attempted to load stopped task with reason [{}]" , deploymentId , reason );
203- }
204- continue ;
205+
206+ var loadingTask = loadingModels .poll ();
207+ if (loadingTask == null ) {
208+ rescheduleImmediately .onResponse (false );
209+ return ;
210+ }
211+
212+ loadModel (loadingTask , ActionListener .wrap (retry -> {
213+ if (retry != null && retry ) {
214+ loadingModels .offer (loadingTask );
215+ // don't reschedule immediately if the next task is the one we just queued, instead wait a bit to retry
216+ rescheduleImmediately .onResponse (loadingModels .peek () != loadingTask );
217+ } else {
218+ rescheduleImmediately .onResponse (loadingModels .isEmpty () == false );
205219 }
206- if (stopped ) {
207- return ;
220+ }, e -> rescheduleImmediately .onResponse (loadingModels .isEmpty () == false )));
221+ }
222+
223+ void loadModel (TrainedModelDeploymentTask loadingTask , ActionListener <Boolean > retryListener ) {
224+ if (loadingTask .isStopped ()) {
225+ if (logger .isTraceEnabled ()) {
226+ logger .trace (
227+ "[{}] attempted to load stopped task with reason [{}]" ,
228+ loadingTask .getDeploymentId (),
229+ loadingTask .stoppedReason ().orElse ("_unknown_" )
230+ );
208231 }
209- final PlainActionFuture <TrainedModelDeploymentTask > listener = new UnsafePlainActionFuture <>(
210- MachineLearning .UTILITY_THREAD_POOL_NAME
211- );
212- try {
213- deploymentManager .startDeployment (loadingTask , listener );
214- // This needs to be synchronous here in the utility thread to keep queueing order
215- TrainedModelDeploymentTask deployedTask = listener .actionGet ();
216- // kicks off asynchronous cluster state update
217- handleLoadSuccess (deployedTask );
218- } catch (Exception ex ) {
232+ retryListener .onResponse (false );
233+ return ;
234+ }
235+ SubscribableListener .<TrainedModelDeploymentTask >newForked (l -> deploymentManager .startDeployment (loadingTask , l ))
236+ .andThen (threadPool .executor (MachineLearning .UTILITY_THREAD_POOL_NAME ), threadPool .getThreadContext (), this ::handleLoadSuccess )
237+ .addListener (retryListener .delegateResponse ((retryL , ex ) -> {
238+ var deploymentId = loadingTask .getDeploymentId ();
219239 logger .warn (() -> "[" + deploymentId + "] Start deployment failed" , ex );
220240 if (ExceptionsHelper .unwrapCause (ex ) instanceof ResourceNotFoundException ) {
221- String modelId = loadingTask .getParams ().getModelId ();
241+ var modelId = loadingTask .getParams ().getModelId ();
222242 logger .debug (() -> "[" + deploymentId + "] Start deployment failed as model [" + modelId + "] was not found" , ex );
223- handleLoadFailure (loadingTask , ExceptionsHelper .missingTrainedModel (modelId , ex ));
243+ handleLoadFailure (loadingTask , ExceptionsHelper .missingTrainedModel (modelId , ex ), retryL );
224244 } else if (ExceptionsHelper .unwrapCause (ex ) instanceof SearchPhaseExecutionException ) {
225245 /*
226246 * This case will not catch the ElasticsearchException generated from the ChunkedTrainedModelRestorer in a scenario
@@ -232,13 +252,11 @@ void loadQueuedModels() {
232252 // A search phase execution failure should be retried, push task back to the queue
233253
234254 // This will cause the entire model to be reloaded (all the chunks)
235- loadingToRetry . add ( loadingTask );
255+ retryL . onResponse ( true );
236256 } else {
237- handleLoadFailure (loadingTask , ex );
257+ handleLoadFailure (loadingTask , ex , retryL );
238258 }
239- }
240- }
241- loadingModels .addAll (loadingToRetry );
259+ }), threadPool .executor (MachineLearning .UTILITY_THREAD_POOL_NAME ), threadPool .getThreadContext ());
242260 }
243261
244262 public void gracefullyStopDeploymentAndNotify (
@@ -680,14 +698,14 @@ void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams)
680698 );
681699 // threadsafe check to verify we are not loading/loaded the model
682700 if (deploymentIdToTask .putIfAbsent (taskParams .getDeploymentId (), task ) == null ) {
683- loadingModels .add (task );
701+ loadingModels .offer (task );
684702 } else {
685703 // If there is already a task for the deployment, unregister the new task
686704 taskManager .unregister (task );
687705 }
688706 }
689707
690- private void handleLoadSuccess (TrainedModelDeploymentTask task ) {
708+ private void handleLoadSuccess (ActionListener < Boolean > retryListener , TrainedModelDeploymentTask task ) {
691709 logger .debug (
692710 () -> "["
693711 + task .getParams ().getDeploymentId ()
@@ -704,13 +722,16 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) {
704722 task .stoppedReason ().orElse ("_unknown_" )
705723 )
706724 );
725+ retryListener .onResponse (false );
707726 return ;
708727 }
709728
710729 updateStoredState (
711730 task .getDeploymentId (),
712731 RoutingInfoUpdate .updateStateAndReason (new RoutingStateAndReason (RoutingState .STARTED , "" )),
713- ActionListener .wrap (r -> logger .debug (() -> "[" + task .getDeploymentId () + "] model loaded and accepting routes" ), e -> {
732+ ActionListener .runAfter (ActionListener .wrap (r -> {
733+ logger .debug (() -> "[" + task .getDeploymentId () + "] model loaded and accepting routes" );
734+ }, e -> {
714735 // This means that either the assignment has been deleted, or this node's particular route has been removed
715736 if (ExceptionsHelper .unwrapCause (e ) instanceof ResourceNotFoundException ) {
716737 logger .debug (
@@ -732,7 +753,7 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) {
732753 e
733754 );
734755 }
735- })
756+ }), () -> retryListener . onResponse ( false ))
736757 );
737758 }
738759
@@ -752,7 +773,7 @@ private void updateStoredState(String deploymentId, RoutingInfoUpdate update, Ac
752773 );
753774 }
754775
755- private void handleLoadFailure (TrainedModelDeploymentTask task , Exception ex ) {
776+ private void handleLoadFailure (TrainedModelDeploymentTask task , Exception ex , ActionListener < Boolean > retryListener ) {
756777 logger .error (() -> "[" + task .getDeploymentId () + "] model [" + task .getParams ().getModelId () + "] failed to load" , ex );
757778 if (task .isStopped ()) {
758779 logger .debug (
@@ -769,14 +790,14 @@ private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) {
769790 Runnable stopTask = () -> stopDeploymentAsync (
770791 task ,
771792 "model failed to load; reason [" + ex .getMessage () + "]" ,
772- ActionListener .noop ( )
793+ ActionListener .running (() -> retryListener . onResponse ( false ) )
773794 );
774795 updateStoredState (
775796 task .getDeploymentId (),
776797 RoutingInfoUpdate .updateStateAndReason (
777798 new RoutingStateAndReason (RoutingState .FAILED , ExceptionsHelper .unwrapCause (ex ).getMessage ())
778799 ),
779- ActionListener .wrap ( r -> stopTask . run (), e -> stopTask . run () )
800+ ActionListener .running ( stopTask )
780801 );
781802 }
782803
0 commit comments