1616import org .elasticsearch .client .internal .Client ;
1717import org .elasticsearch .client .internal .OriginSettingClient ;
1818import org .elasticsearch .client .internal .ParentTaskAssigningClient ;
19+ import org .elasticsearch .cluster .ClusterName ;
1920import org .elasticsearch .cluster .ClusterState ;
2021import org .elasticsearch .cluster .block .ClusterBlockException ;
2122import org .elasticsearch .cluster .block .ClusterBlockLevel ;
3536import org .elasticsearch .xpack .core .ml .action .MlMemoryAction .Response .MlMemoryStats ;
3637import org .elasticsearch .xpack .core .ml .action .TrainedModelCacheInfoAction ;
3738import org .elasticsearch .xpack .core .ml .action .TrainedModelCacheInfoAction .Response .CacheInfo ;
39+ import org .elasticsearch .xpack .core .ml .inference .assignment .TrainedModelAssignmentMetadata ;
3840import org .elasticsearch .xpack .ml .job .NodeLoad ;
3941import org .elasticsearch .xpack .ml .job .NodeLoadDetector ;
4042import org .elasticsearch .xpack .ml .process .MlMemoryTracker ;
@@ -88,6 +90,9 @@ protected void masterOperation(
8890
8991 ClusterSettings clusterSettings = clusterService .getClusterSettings ();
9092
93+ var clusterName = state .getClusterName ();
94+ var trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata .fromState (state );
95+ PersistentTasksCustomMetadata persistentTasksCustomMetadata = state .getMetadata ().custom (PersistentTasksCustomMetadata .TYPE );
9196 // Resolve the node specification to some concrete nodes
9297 String [] nodeIds = state .nodes ().resolveNodes (request .getNodeId ());
9398
@@ -116,7 +121,9 @@ protected void masterOperation(
116121 trainedModelCacheInfoRequest ,
117122 delegate2 .delegateFailureAndWrap (
118123 (l , trainedModelCacheInfoResponse ) -> handleResponses (
119- state ,
124+ clusterName ,
125+ persistentTasksCustomMetadata ,
126+ trainedModelAssignmentMetadata ,
120127 clusterSettings ,
121128 nodesStatsResponse ,
122129 trainedModelCacheInfoResponse ,
@@ -131,12 +138,14 @@ protected void masterOperation(
131138 if (memoryTracker .isEverRefreshed ()) {
132139 memoryTrackerRefreshListener .onResponse (null );
133140 } else {
134- memoryTracker .refresh (state . getMetadata (). custom ( PersistentTasksCustomMetadata . TYPE ) , memoryTrackerRefreshListener );
141+ memoryTracker .refresh (persistentTasksCustomMetadata , memoryTrackerRefreshListener );
135142 }
136143 }
137144
138145 void handleResponses (
139- ClusterState state ,
146+ ClusterName clusterName ,
147+ PersistentTasksCustomMetadata persistentTasks ,
148+ TrainedModelAssignmentMetadata assignmentMetadata ,
140149 ClusterSettings clusterSettings ,
141150 NodesStatsResponse nodesStatsResponse ,
142151 TrainedModelCacheInfoAction .Response trainedModelCacheInfoResponse ,
@@ -174,7 +183,8 @@ void handleResponses(
174183 ByteSizeValue mlNativeInference ;
175184 if (node .getRoles ().contains (DiscoveryNodeRole .ML_ROLE )) {
176185 NodeLoad nodeLoad = nodeLoadDetector .detectNodeLoad (
177- state ,
186+ persistentTasks ,
187+ assignmentMetadata ,
178188 node ,
179189 maxOpenJobsPerNode ,
180190 maxMachineMemoryPercent ,
@@ -220,7 +230,7 @@ void handleResponses(
220230 );
221231 }
222232
223- listener .onResponse (new MlMemoryAction .Response (state . getClusterName () , nodeResponses , failures ));
233+ listener .onResponse (new MlMemoryAction .Response (clusterName , nodeResponses , failures ));
224234 }
225235
226236 @ Override
0 commit comments