Skip to content

Commit f253965

Browse files
authored
[ML] improve trained model stats API performance (#87978) (#88129)
Previously, the get trained model stats API would build every pipeline defined in cluster state. This is problematic when MANY pipelines are defined, especially if those pipelines take some time to parse (consider GROK). This improvement is part of fixing #87931.
1 parent 56063db commit f253965

File tree

9 files changed

+387
-300
lines changed

9 files changed

+387
-300
lines changed

docs/changelog/87978.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 87978
2+
summary: Improve trained model stats API performance
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import java.util.concurrent.TimeUnit;
4848
import java.util.stream.Collectors;
4949

50-
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
5150
import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
5251
import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
5352
import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
@@ -56,6 +55,7 @@
5655
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
5756
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
5857
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
58+
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;
5959
import static org.hamcrest.Matchers.containsString;
6060
import static org.hamcrest.Matchers.empty;
6161
import static org.hamcrest.Matchers.equalTo;
@@ -132,9 +132,7 @@ public void testMLFeatureReset() throws Exception {
132132
client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet();
133133
createdPipelines.remove("feature_reset_inference_pipeline");
134134

135-
assertBusy(
136-
() -> assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0))
137-
);
135+
assertBusy(() -> assertThat(countInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0)));
138136
client().execute(ResetFeatureStateAction.INSTANCE, new ResetFeatureStateRequest()).actionGet();
139137
assertBusy(() -> {
140138
List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@
445445
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
446446
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.RESULTS_INDEX_PREFIX;
447447
import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX;
448-
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
448+
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.countInferenceProcessors;
449449

450450
public class MachineLearning extends Plugin
451451
implements
@@ -1883,7 +1883,7 @@ public void cleanUpFeature(
18831883

18841884
// validate no pipelines are using machine learning models
18851885
ActionListener<AcknowledgedResponse> afterResetModeSet = ActionListener.wrap(acknowledgedResponse -> {
1886-
int numberInferenceProcessors = countNumberInferenceProcessors(clusterService.state());
1886+
int numberInferenceProcessors = countInferenceProcessors(clusterService.state());
18871887
if (numberInferenceProcessors > 0) {
18881888
unsetResetModeListener.onFailure(
18891889
new RuntimeException(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
*/
77
package org.elasticsearch.xpack.ml.action;
88

9-
import org.elasticsearch.ElasticsearchException;
109
import org.elasticsearch.action.ActionListener;
1110
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
1211
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
@@ -27,10 +26,7 @@
2726
import org.elasticsearch.core.Tuple;
2827
import org.elasticsearch.index.query.QueryBuilder;
2928
import org.elasticsearch.index.query.QueryBuilders;
30-
import org.elasticsearch.ingest.IngestMetadata;
31-
import org.elasticsearch.ingest.IngestService;
3229
import org.elasticsearch.ingest.IngestStats;
33-
import org.elasticsearch.ingest.Pipeline;
3430
import org.elasticsearch.search.SearchHit;
3531
import org.elasticsearch.search.sort.SortOrder;
3632
import org.elasticsearch.tasks.Task;
@@ -46,15 +42,13 @@
4642
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
4743
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
4844
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
49-
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
5045
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
5146
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
5247

5348
import java.util.ArrayList;
5449
import java.util.Collections;
5550
import java.util.HashMap;
5651
import java.util.LinkedHashMap;
57-
import java.util.LinkedHashSet;
5852
import java.util.List;
5953
import java.util.Map;
6054
import java.util.Set;
@@ -64,29 +58,27 @@
6458

6559
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
6660
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
61+
import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByModelIdsOrAliases;
6762

6863
public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<
6964
GetTrainedModelsStatsAction.Request,
7065
GetTrainedModelsStatsAction.Response> {
7166

7267
private final Client client;
7368
private final ClusterService clusterService;
74-
private final IngestService ingestService;
7569
private final TrainedModelProvider trainedModelProvider;
7670

7771
@Inject
7872
public TransportGetTrainedModelsStatsAction(
7973
TransportService transportService,
8074
ActionFilters actionFilters,
8175
ClusterService clusterService,
82-
IngestService ingestService,
8376
TrainedModelProvider trainedModelProvider,
8477
Client client
8578
) {
8679
super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
8780
this.client = client;
8881
this.clusterService = clusterService;
89-
this.ingestService = ingestService;
9082
this.trainedModelProvider = trainedModelProvider;
9183
}
9284

@@ -133,7 +125,6 @@ protected void doExecute(
133125
.collect(Collectors.toSet());
134126
Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(
135127
clusterService.state(),
136-
ingestService,
137128
allPossiblePipelineReferences
138129
);
139130
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(
@@ -261,37 +252,6 @@ static String[] ingestNodes(final ClusterState clusterState) {
261252
return clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new);
262253
}
263254

264-
static Map<String, Set<String>> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set<String> modelIds) {
265-
IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE);
266-
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
267-
if (ingestMetadata == null) {
268-
return pipelineIdsByModelIds;
269-
}
270-
271-
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
272-
try {
273-
Pipeline pipeline = Pipeline.create(
274-
pipelineId,
275-
pipelineConfiguration.getConfigAsMap(),
276-
ingestService.getProcessorFactories(),
277-
ingestService.getScriptService()
278-
);
279-
pipeline.getProcessors().forEach(processor -> {
280-
if (processor instanceof InferenceProcessor inferenceProcessor) {
281-
if (modelIds.contains(inferenceProcessor.getModelId())) {
282-
pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(), m -> new LinkedHashSet<>())
283-
.add(pipelineId);
284-
}
285-
}
286-
});
287-
} catch (Exception ex) {
288-
throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
289-
}
290-
});
291-
292-
return pipelineIdsByModelIds;
293-
}
294-
295255
static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
296256
IngestStats fullNodeStats = nodeStats.getIngestStats();
297257
Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

Lines changed: 6 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,13 @@
1414
import org.elasticsearch.action.ActionListener;
1515
import org.elasticsearch.client.internal.Client;
1616
import org.elasticsearch.cluster.ClusterState;
17-
import org.elasticsearch.cluster.metadata.Metadata;
1817
import org.elasticsearch.cluster.service.ClusterService;
1918
import org.elasticsearch.common.settings.Setting;
2019
import org.elasticsearch.common.settings.Settings;
2120
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
2221
import org.elasticsearch.ingest.AbstractProcessor;
2322
import org.elasticsearch.ingest.ConfigurationUtils;
2423
import org.elasticsearch.ingest.IngestDocument;
25-
import org.elasticsearch.ingest.IngestMetadata;
26-
import org.elasticsearch.ingest.Pipeline;
27-
import org.elasticsearch.ingest.PipelineConfiguration;
2824
import org.elasticsearch.ingest.Processor;
2925
import org.elasticsearch.rest.RestStatus;
3026
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -55,6 +51,7 @@
5551
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
5652
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
5753
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
54+
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;
5855

5956
import java.util.Collections;
6057
import java.util.HashMap;
@@ -65,7 +62,6 @@
6562
import java.util.function.Consumer;
6663

6764
import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY;
68-
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;
6965
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
7066
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
7167
import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD;
@@ -192,9 +188,6 @@ public String getType() {
192188

193189
public static final class Factory implements Processor.Factory, Consumer<ClusterState> {
194190

195-
private static final String FOREACH_PROCESSOR_NAME = "foreach";
196-
// Any more than 10 nestings of processors, we stop searching for inference processor definitions
197-
private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
198191
private static final Logger logger = LogManager.getLogger(Factory.class);
199192

200193
private final Client client;
@@ -213,86 +206,12 @@ public Factory(Client client, ClusterService clusterService, Settings settings)
213206
@Override
214207
public void accept(ClusterState state) {
215208
minNodeVersion = state.nodes().getMinNodeVersion();
216-
currentInferenceProcessors = countNumberInferenceProcessors(state);
217-
}
218-
219-
public static int countNumberInferenceProcessors(ClusterState state) {
220-
Metadata metadata = state.getMetadata();
221-
if (metadata == null) {
222-
return 0;
223-
}
224-
IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE);
225-
if (ingestMetadata == null) {
226-
return 0;
227-
}
228-
229-
int count = 0;
230-
for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
231-
Map<String, Object> configMap = configuration.getConfigAsMap();
232-
try {
233-
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
234-
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
235-
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
236-
count += numInferenceProcessors(entry.getKey(), entry.getValue());
237-
}
238-
}
239-
// We cannot throw any exception here. It might break other pipelines.
240-
} catch (Exception ex) {
241-
logger.debug(() -> "failed gathering processors for pipeline [" + configuration.getId() + "]", ex);
242-
}
209+
try {
210+
currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state);
211+
} catch (Exception ex) {
212+
// We cannot throw any exception here. It might break other pipelines.
213+
logger.debug("failed gathering processors for pipelines", ex);
243214
}
244-
return count;
245-
}
246-
247-
@SuppressWarnings("unchecked")
248-
static int numInferenceProcessors(String processorType, Object processorDefinition) {
249-
return numInferenceProcessors(processorType, (Map<String, Object>) processorDefinition, 0);
250-
}
251-
252-
@SuppressWarnings("unchecked")
253-
static int numInferenceProcessors(String processorType, Map<String, Object> processorDefinition, int level) {
254-
int count = 0;
255-
// arbitrary, but we must limit this somehow
256-
if (level > MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS) {
257-
return count;
258-
}
259-
if (processorType == null || processorDefinition == null) {
260-
return count;
261-
}
262-
if (TYPE.equals(processorType)) {
263-
count++;
264-
}
265-
if (FOREACH_PROCESSOR_NAME.equals(processorType)) {
266-
Map<String, Object> innerProcessor = (Map<String, Object>) processorDefinition.get("processor");
267-
if (innerProcessor != null) {
268-
// a foreach processor should only have a SINGLE nested processor. Iteration is for simplicity's sake.
269-
for (Map.Entry<String, Object> innerProcessorWithName : innerProcessor.entrySet()) {
270-
count += numInferenceProcessors(
271-
innerProcessorWithName.getKey(),
272-
(Map<String, Object>) innerProcessorWithName.getValue(),
273-
level + 1
274-
);
275-
}
276-
}
277-
}
278-
if (processorDefinition.containsKey(Pipeline.ON_FAILURE_KEY)) {
279-
List<Map<String, Object>> onFailureConfigs = ConfigurationUtils.readList(
280-
null,
281-
null,
282-
processorDefinition,
283-
Pipeline.ON_FAILURE_KEY
284-
);
285-
count += onFailureConfigs.stream()
286-
.flatMap(map -> map.entrySet().stream())
287-
.mapToInt(entry -> numInferenceProcessors(entry.getKey(), (Map<String, Object>) entry.getValue(), level + 1))
288-
.sum();
289-
}
290-
return count;
291-
}
292-
293-
// Used for testing
294-
int numInferenceProcessors() {
295-
return currentInferenceProcessors;
296215
}
297216

298217
@Override

0 commit comments

Comments
 (0)