99import static org .opensearch .ml .common .CommonValue .MASTER_KEY ;
1010import static org .opensearch .ml .common .CommonValue .ML_CONFIG_INDEX ;
1111import static org .opensearch .ml .common .CommonValue .ML_MODEL_INDEX ;
12+ import static org .opensearch .ml .utils .RestActionUtils .getAllNodes ;
1213
1314import java .time .Instant ;
1415import java .util .ArrayList ;
4142import org .opensearch .ml .common .transport .sync .MLSyncUpInput ;
4243import org .opensearch .ml .common .transport .sync .MLSyncUpNodeResponse ;
4344import org .opensearch .ml .common .transport .sync .MLSyncUpNodesRequest ;
44- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelAction ;
45- import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesRequest ;
45+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelNodesResponse ;
46+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsAction ;
47+ import org .opensearch .ml .common .transport .undeploy .MLUndeployModelsRequest ;
4648import org .opensearch .ml .engine .encryptor .Encryptor ;
4749import org .opensearch .ml .engine .indices .MLIndicesHandler ;
4850import org .opensearch .search .SearchHit ;
@@ -97,6 +99,14 @@ public void run() {
9799 // gather running model/tasks on nodes
98100 client .execute (MLSyncUpAction .INSTANCE , gatherInfoRequest , ActionListener .wrap (r -> {
99101 List <MLSyncUpNodeResponse > responses = r .getNodes ();
102+ if (r .failures () != null && r .failures ().size () != 0 ) {
103+ log
104+ .debug (
105+ "Received {} failures in the sync up response on nodes. Error messages are {}" ,
106+ r .failures ().size (),
107+ r .failures ().stream ().map (Exception ::getMessage ).collect (Collectors .joining (", " ))
108+ );
109+ }
100110 // key is model id, value is set of worker node ids
101111 Map <String , Set <String >> modelWorkerNodes = new HashMap <>();
102112 // key is task id, value is set of worker node ids
@@ -143,7 +153,6 @@ public void run() {
143153 if (modelWorkerNodes .containsKey (modelId )
144154 && expiredModelToNodes .get (modelId ).size () == modelWorkerNodes .get (modelId ).size ()) {
145155 // this model has expired in all the nodes
146- modelWorkerNodes .remove (modelId );
147156 modelsToUndeploy .add (modelId );
148157 }
149158 }
@@ -168,37 +177,44 @@ public void run() {
168177 MLSyncUpInput syncUpInput = inputBuilder .build ();
169178 MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest (allNodes , syncUpInput );
170179 // sync up running model/tasks on nodes
171- client
172- .execute (
173- MLSyncUpAction .INSTANCE ,
174- syncUpRequest ,
175- ActionListener .wrap (re -> { log .debug ("sync model routing job finished" ); }, ex -> {
176- log .error ("Failed to sync model routing" , ex );
177- })
178- );
179- // Undeploy expired models
180- undeployExpiredModels (modelsToUndeploy , modelWorkerNodes );
180+ client .execute (MLSyncUpAction .INSTANCE , syncUpRequest , ActionListener .wrap (re -> {
181+ log .debug ("sync model routing job finished" );
182+ if (!modelsToUndeploy .isEmpty ()) {
183+ // Undeploy expired models
184+ undeployExpiredModels (modelsToUndeploy , modelWorkerNodes , deployingModels );
185+ return ;
186+ }
187+ // refresh model status
188+ mlIndicesHandler
189+ .initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
190+ log .error ("Failed to init model index" , e );
191+ }));
192+ }, ex -> { log .error ("Failed to sync model routing" , ex ); }));
193+ }, e -> { log .error ("Failed to sync model routing" , e ); }));
194+ }
195+
196+ private void undeployExpiredModels (
197+ Set <String > expiredModels ,
198+ Map <String , Set <String >> modelWorkerNodes ,
199+ Map <String , Set <String >> deployingModels
200+ ) {
201+ String [] targetNodeIds = getAllNodes (clusterService );
202+ MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest (
203+ expiredModels .toArray (new String [expiredModels .size ()]),
204+ targetNodeIds
205+ );
206+
207+ client .execute (MLUndeployModelsAction .INSTANCE , mlUndeployModelsRequest , ActionListener .wrap (r -> {
208+ MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r .getResponse ();
209+ if (mlUndeployModelNodesResponse .failures () != null && mlUndeployModelNodesResponse .failures ().size () != 0 ) {
210+ log .debug ("Received failures in undeploying expired models" , mlUndeployModelNodesResponse .failures ());
211+ }
181212
182- // refresh model status
183213 mlIndicesHandler
184214 .initModelIndexIfAbsent (ActionListener .wrap (res -> { refreshModelState (modelWorkerNodes , deployingModels ); }, e -> {
185215 log .error ("Failed to init model index" , e );
186216 }));
187- }, e -> { log .error ("Failed to sync model routing" , e ); }));
188- }
189-
190- private void undeployExpiredModels (Set <String > expiredModels , Map <String , Set <String >> modelWorkerNodes ) {
191- expiredModels .forEach (modelId -> {
192- String [] targetNodeIds = modelWorkerNodes .keySet ().toArray (new String [0 ]);
193-
194- MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest (
195- targetNodeIds ,
196- new String [] { modelId }
197- );
198- client .execute (MLUndeployModelAction .INSTANCE , mlUndeployModelNodesRequest , ActionListener .wrap (r -> {
199- log .debug ("model {} is un_deployed" , modelId );
200- }, e -> { log .error ("Failed to undeploy model {}" , modelId , e ); }));
201- });
217+ }, e -> { log .error ("Failed to undeploy models {}" , expiredModels , e ); }));
202218 }
203219
204220 @ VisibleForTesting
0 commit comments