Commit 0f5023e

add more stats: request/failure/model count on algo/action level (#159)
* add more stats: request/failure/model count on algo/action level
  Signed-off-by: Yaliang Wu <[email protected]>
* fix missing stats when get all stats
  Signed-off-by: Yaliang Wu <[email protected]>
1 parent e55fffa commit 0f5023e

19 files changed, +310 -157 lines changed

plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesRequest.java

Lines changed: 14 additions & 0 deletions
@@ -24,9 +24,14 @@ public class MLStatsNodesRequest extends BaseNodesRequest<MLStatsNodesRequest> {
 
     @Getter
     private Set<String> statsToBeRetrieved;
+    /**
+     * If set this field as true, will retrieve all stats.
+     */
+    private boolean retrieveAllStats = false;
 
     public MLStatsNodesRequest(StreamInput in) throws IOException {
         super(in);
+        retrieveAllStats = in.readBoolean();
         statsToBeRetrieved = in.readSet(StreamInput::readString);
     }
 
@@ -50,6 +55,14 @@ public MLStatsNodesRequest(DiscoveryNode... nodes) {
         statsToBeRetrieved = new HashSet<>();
     }
 
+    public boolean isRetrieveAllStats() {
+        return retrieveAllStats;
+    }
+
+    public void setRetrieveAllStats(boolean retrieveAllStats) {
+        this.retrieveAllStats = retrieveAllStats;
+    }
+
     /**
      * Adds a stat to the set of stats to be retrieved
      *
@@ -82,6 +95,7 @@ public void readFrom(StreamInput in) throws IOException {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
+        out.writeBoolean(retrieveAllStats);
         out.writeStringCollection(statsToBeRetrieved);
     }
 }
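A minimal caller-side sketch (hypothetical, not part of this commit) of how the two retrieval modes would be chosen when building a stats request; it assumes the MLStatsNodesRequest(DiscoveryNode...) constructor and addAll(Set<String>) shown in the surrounding context:

import java.util.Set;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;

public class StatsRequestSketch {
    // Build a node-stats request: either ask for everything via the new flag,
    // or fall back to the pre-existing explicit stat-name set.
    MLStatsNodesRequest buildRequest(DiscoveryNode[] nodes, Set<String> requestedStats, boolean allStats) {
        MLStatsNodesRequest request = new MLStatsNodesRequest(nodes);
        if (allStats) {
            // New in this commit: each node returns every stat it knows about,
            // including counters created lazily after this request was built.
            request.setRetrieveAllStats(true);
        } else {
            request.addAll(requestedStats);
        }
        return request;
    }
}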

plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java

Lines changed: 3 additions & 2 deletions
@@ -90,14 +90,15 @@ protected MLStatsNodeResponse nodeOperation(MLStatsNodeRequest request) {
     private MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRequest) {
         Map<String, Object> statValues = new HashMap<>();
         Set<String> statsToBeRetrieved = mlStatsNodesRequest.getStatsToBeRetrieved();
+        boolean retrieveAllStats = mlStatsNodesRequest.isRetrieveAllStats();
 
-        if (statsToBeRetrieved.contains(InternalStatNames.JVM_HEAP_USAGE.getName())) {
+        if (retrieveAllStats || statsToBeRetrieved.contains(InternalStatNames.JVM_HEAP_USAGE.getName())) {
             long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent();
             statValues.put(InternalStatNames.JVM_HEAP_USAGE.getName(), heapUsedPercent);
         }
 
         for (String statName : mlStats.getNodeStats().keySet()) {
-            if (statsToBeRetrieved.contains(statName)) {
+            if (retrieveAllStats || statsToBeRetrieved.contains(statName)) {
                 statValues.put(statName, mlStats.getStats().get(statName).getValue());
             }
         }
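The node-side selection rule above reduces to a simple predicate; a tiny standalone sketch (hypothetical helper, not in the commit) of the same logic:

import java.util.Set;

final class StatSelectionSketch {
    // A stat is reported when "all stats" was requested, or when its name was asked for explicitly.
    static boolean shouldInclude(boolean retrieveAllStats, Set<String> statsToBeRetrieved, String statName) {
        return retrieveAllStats || statsToBeRetrieved.contains(statName);
    }
}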

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 6 additions & 5 deletions
@@ -9,6 +9,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 
 import org.opensearch.action.ActionRequest;
@@ -89,7 +90,6 @@
 import org.opensearch.watcher.ResourceWatcherService;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 
 public class MachineLearningPlugin extends Plugin implements ActionPlugin {
     public static final String TASK_THREAD_POOL = "OPENSEARCH_ML_TASK_THREAD_POOL";
@@ -157,10 +157,11 @@ public Collection<Object> createComponents(
         JvmService jvmService = new JvmService(environment.settings());
         MLCircuitBreakerService mlCircuitBreakerService = new MLCircuitBreakerService(jvmService).init();
 
-        Map<String, MLStat<?>> stats = ImmutableMap
-            .<String, MLStat<?>>builder()
-            .put(StatNames.ML_EXECUTING_TASK_COUNT.getName(), new MLStat<>(false, new CounterSupplier()))
-            .build();
+        Map<String, MLStat<?>> stats = new ConcurrentHashMap<>();
+        stats.put(StatNames.ML_EXECUTING_TASK_COUNT, new MLStat<>(false, new CounterSupplier()));
+        stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
+        stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
+        stats.put(StatNames.ML_TOTAL_MODEL_COUNT, new MLStat<>(false, new CounterSupplier()));
        this.mlStats = new MLStats(stats);
 
         mlIndicesHandler = new MLIndicesHandler(clusterService, client);
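The switch from an ImmutableMap to a ConcurrentHashMap is what lets algorithm/action-level counters be registered lazily and safely from concurrent task threads, while the cluster-level totals stay eagerly registered. A sketch of the same initialization pattern (mirrors the hunk above; assumes the plugin's MLStat, MLStats, StatNames, and CounterSupplier types):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.StatNames;
import org.opensearch.ml.stats.suppliers.CounterSupplier;

class StatsRegistrySketch {
    MLStats buildStats() {
        // Mutable, thread-safe registry: totals are registered up front,
        // per-algorithm/action counters are added later via createCounterStatIfAbsent.
        Map<String, MLStat<?>> stats = new ConcurrentHashMap<>();
        stats.put(StatNames.ML_TOTAL_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier()));
        stats.put(StatNames.ML_TOTAL_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier()));
        return new MLStats(stats);
    }
}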

plugin/src/main/java/org/opensearch/ml/rest/RestStatsMLAction.java

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ MLStatsNodesRequest getRequest(RestRequest request) {
 
         Set<String> validStats = mlStats.getStats().keySet();
         if (isAllStatsRequested(requestedStats)) {
-            mlStatsRequest.addAll(validStats);
+            mlStatsRequest.setRetrieveAllStats(true);
         } else {
             mlStatsRequest.addAll(getStatsToBeRetrieved(request, validStats, requestedStats));
         }
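Using setRetrieveAllStats(true) instead of copying validStats matters because validStats is only a snapshot of the stat keys that exist when the REST request is parsed; per-algorithm counters created later would be missing from an "all stats" response built from that snapshot, which appears to be the "fix missing stats when get all stats" part of the commit message. A hypothetical illustration, assuming the plugin's MLStats type and an example stat key:

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.ml.stats.MLStats;

class SnapshotProblemSketch {
    void demo() {
        MLStats mlStats = new MLStats(new ConcurrentHashMap<>());
        // Snapshot of stat names taken when the request is built (empty here).
        Set<String> snapshot = new HashSet<>(mlStats.getStats().keySet());
        // A counter created afterwards, e.g. by the first predict call for an algorithm
        // ("ml_kmeans_predict_request_count" is a hypothetical key).
        mlStats.createCounterStatIfAbsent("ml_kmeans_predict_request_count").increment();
        // The snapshot does not contain the new key, so an "all stats" request built from it
        // would silently skip the counter; setRetrieveAllStats(true) avoids the snapshot entirely.
        boolean missed = !snapshot.contains("ml_kmeans_predict_request_count"); // true
    }
}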
plugin/src/main/java/org/opensearch/ml/stats/ActionName.java

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.stats;
+
+public enum ActionName {
+    TRAIN,
+    PREDICT,
+    TRAIN_PREDICT;
+}

plugin/src/main/java/org/opensearch/ml/stats/MLStats.java

Lines changed: 22 additions & 0 deletions
@@ -7,9 +7,12 @@
 
 import java.util.HashMap;
 import java.util.Map;
+import java.util.function.Supplier;
 
 import lombok.Getter;
 
+import org.opensearch.ml.stats.suppliers.CounterSupplier;
+
 /**
  * This class is the main entry-point for access to the stats that the ML plugin keeps track of.
  */
@@ -40,6 +43,25 @@ public MLStat<?> getStat(String key) throws IllegalArgumentException {
         return stats.get(key);
     }
 
+    /**
+     * Get stat or create counter stat if absent.
+     * @param key stat key
+     * @return existing MLStat or new MLStat
+     */
+    public MLStat<?> createCounterStatIfAbsent(String key) {
+        return createStatIfAbsent(key, () -> new MLStat<>(false, new CounterSupplier()));
+    }
+
+    /**
+     * Get stat or create if absent.
+     * @param key stat key
+     * @param supplier supplier to create MLStat
+     * @return existing MLStat or new MLStat
+     */
+    public synchronized MLStat<?> createStatIfAbsent(String key, Supplier<MLStat> supplier) {
+        return stats.computeIfAbsent(key, k -> supplier.get());
+    }
+
     /**
      * Get a map of the stats that are kept at the node level
      *
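A small usage sketch for the new helper (hypothetical; mirrors how MLPredictTaskRunner uses it below): repeated calls with the same key return the same counter, so increments accumulate rather than creating a fresh stat each time.

import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.ml.stats.MLStat;
import org.opensearch.ml.stats.MLStats;

class CreateIfAbsentSketch {
    void demo() {
        MLStats mlStats = new MLStats(new ConcurrentHashMap<>());
        // First call creates the counter, the second call returns the existing one.
        MLStat<?> counter = mlStats.createCounterStatIfAbsent("ml_kmeans_train_request_count");
        counter.increment();
        mlStats.createCounterStatIfAbsent("ml_kmeans_train_request_count").increment();
        // Both increments landed on the same underlying counter.
        Object value = mlStats.getStat("ml_kmeans_train_request_count").getValue();
    }
}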

plugin/src/main/java/org/opensearch/ml/stats/StatNames.java

Lines changed: 18 additions & 21 deletions
@@ -5,35 +5,32 @@
 
 package org.opensearch.ml.stats;
 
-import java.util.HashSet;
-import java.util.Set;
+import java.util.Locale;
 
-import lombok.Getter;
+import org.opensearch.ml.common.parameter.FunctionName;
 
 /**
  * Enum containing names of all stats
  */
-public enum StatNames {
-    ML_EXECUTING_TASK_COUNT("ml_executing_task_count");
+public class StatNames {
+    public static String ML_EXECUTING_TASK_COUNT = "ml_executing_task_count";
+    public static String ML_TOTAL_REQUEST_COUNT = "ml_total_request_count";
+    public static String ML_TOTAL_FAILURE_COUNT = "ml_total_failure_count";
+    public static String ML_TOTAL_MODEL_COUNT = "ml_total_model_count";
+
+    public static String requestCountStat(FunctionName functionName, ActionName actionName) {
+        return String.format("ml_%s_%s_request_count", functionName, actionName, Locale.ROOT).toLowerCase(Locale.ROOT);
+    }
 
-    @Getter
-    private String name;
+    public static String failureCountStat(FunctionName functionName, ActionName actionName) {
+        return String.format("ml_%s_%s_failure_count", functionName, actionName, Locale.ROOT).toLowerCase(Locale.ROOT);
+    }
 
-    StatNames(String name) {
-        this.name = name;
+    public static String executingRequestCountStat(FunctionName functionName, ActionName actionName) {
+        return String.format("ml_%s_%s_executing_request_count", functionName, actionName, Locale.ROOT).toLowerCase(Locale.ROOT);
     }
 
-    /**
-     * Get set of stat names
-     *
-     * @return set of stat names
-     */
-    public static Set<String> getNames() {
-        Set<String> names = new HashSet<>();
-
-        for (StatNames statName : StatNames.values()) {
-            names.add(statName.getName());
-        }
-        return names;
+    public static String modelCountStat(FunctionName functionName) {
+        return String.format("ml_%s_model_count", functionName, Locale.ROOT).toLowerCase(Locale.ROOT);
     }
 }
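For reference, the stat keys these helpers produce (hypothetical example, assuming FunctionName defines a KMEANS value; the format strings come from the class above):

import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.StatNames;

class StatNameSketch {
    void demo() {
        // Enum names are lowercased into the key, e.g. KMEANS/PREDICT below.
        String requestCount = StatNames.requestCountStat(FunctionName.KMEANS, ActionName.PREDICT);
        // -> "ml_kmeans_predict_request_count"
        String failureCount = StatNames.failureCountStat(FunctionName.KMEANS, ActionName.PREDICT);
        // -> "ml_kmeans_predict_failure_count"
        String modelCount = StatNames.modelCountStat(FunctionName.KMEANS);
        // -> "ml_kmeans_model_count"
    }
}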

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

Lines changed: 64 additions & 71 deletions
@@ -10,6 +10,10 @@
 import static org.opensearch.ml.permission.AccessController.getUserContext;
 import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
 import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT;
+import static org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT;
+import static org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT;
+import static org.opensearch.ml.stats.StatNames.failureCountStat;
+import static org.opensearch.ml.stats.StatNames.requestCountStat;
 
 import java.time.Instant;
 import java.util.Base64;
@@ -46,6 +50,7 @@
 import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.engine.MLEngine;
 import org.opensearch.ml.indices.MLInputDatasetHandler;
+import org.opensearch.ml.stats.ActionName;
 import org.opensearch.ml.stats.MLStats;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.transport.TransportService;
@@ -146,86 +151,74 @@ private void predict(
     ) {
         ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());
         // track ML task count and add ML task into cache
-        mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
+        mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
+        mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
+        mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
         mlTaskManager.add(mlTask);
-        MLInput mlInput = request.getMlInput();
 
         // run predict
-        try {
+        if (request.getModelId() != null) {
             // search model by model id.
-            Model model = new Model();
-            if (request.getModelId() != null) {
+            try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
+                MLInput mlInput = request.getMlInput();
+                ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
+                    if (r == null || !r.isExists()) {
+                        internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
+                        return;
+                    }
+                    Map<String, Object> source = r.getSourceAsMap();
+                    User requestUser = getUserContext(client);
+                    User resourceUser = User.parse((String) source.get(USER));
+                    if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
+                        // The backend roles of request user and resource user doesn't have intersection
+                        OpenSearchException e = new OpenSearchException(
+                            "User: " + requestUser.getName() + " does not have permissions to run predict by model: " + request.getModelId()
+                        );
+                        log.debug(e);
+                        handlePredictFailure(mlTask, internalListener, e, false);
+                        return;
+                    }
+
+                    Model model = new Model();
+                    model.setName((String) source.get(MLModel.MODEL_NAME));
+                    model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
+                    byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
+                    model.setContent(decoded);
+
+                    // run predict
+                    mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
+                    MLOutput output = MLEngine
+                        .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
+                    if (output instanceof MLPredictionOutput) {
+                        ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
+                    }
+
+                    // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
+                    handleAsyncMLTaskComplete(mlTask);
+                    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
+                    internalListener.onResponse(response);
+                }, e -> {
+                    log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
+                    handlePredictFailure(mlTask, internalListener, e, true);
+                });
                 GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
-                try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
-                    ActionListener<GetResponse> getResponseListener = ActionListener.wrap(r -> {
-                        if (r == null || !r.isExists()) {
-                            internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
-                            return;
-                        }
-                        Map<String, Object> source = r.getSourceAsMap();
-                        User requestUser = getUserContext(client);
-                        User resourceUser = User.parse((String) source.get(USER));
-                        if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
-                            // The backend roles of request user and resource user doesn't have intersection
-                            OpenSearchException e = new OpenSearchException(
-                                "User: "
-                                    + requestUser.getName()
-                                    + " does not have permissions to run predict by model: "
-                                    + request.getModelId()
-                            );
-                            log.debug(e);
-                            handlePredictFailure(mlTask, internalListener, e);
-                            return;
-                        }
-
-                        model.setName((String) source.get(MLModel.MODEL_NAME));
-                        model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
-                        byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
-                        model.setContent(decoded);
-
-                        // run predict
-                        MLOutput output;
-                        try {
-                            mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
-                            output = MLEngine
-                                .predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
-                            if (output instanceof MLPredictionOutput) {
-                                ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
-                            }
-
-                            // Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
-                            handleAsyncMLTaskComplete(mlTask);
-                        } catch (Exception e) {
-                            // todo need to specify what exception
-                            log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + model.getName(), e);
-                            handlePredictFailure(mlTask, internalListener, e);
-                            return;
-                        }
-
-                        MLTaskResponse response = MLTaskResponse.builder().output(output).build();
-                        internalListener.onResponse(response);
-                    }, e -> {
-                        log.error("Failed to predict model " + mlTask.getModelId(), e);
-                        internalListener.onFailure(e);
-                    });
-                    client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
-                } catch (Exception e) {
-                    log.error("Failed to get model " + mlTask.getModelId(), e);
-                    internalListener.onFailure(e);
-                }
-            } else {
-                IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
-                log.error("ModelId is invalid", e);
-                handlePredictFailure(mlTask, internalListener, e);
-                return;
+                client.get(getRequest, ActionListener.runBefore(getResponseListener, () -> context.restore()));
+            } catch (Exception e) {
+                log.error("Failed to get model " + mlTask.getModelId(), e);
+                handlePredictFailure(mlTask, internalListener, e, true);
             }
-        } catch (Exception e) {
-            log.error("Failed to predict " + mlInput.getAlgorithm(), e);
-            internalListener.onFailure(e);
+        } else {
+            IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
+            log.error("ModelId is invalid", e);
+            handlePredictFailure(mlTask, internalListener, e, false);
         }
     }
 
-    private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e) {
+    private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e, boolean trackFailure) {
+        if (trackFailure) {
+            mlStats.createCounterStatIfAbsent(failureCountStat(mlTask.getFunctionName(), ActionName.PREDICT)).increment();
+            mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
+        }
         handleAsyncMLTaskFailure(mlTask, e);
        listener.onFailure(e);
     }
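The new trackFailure flag makes the failure counters selective: total and per-algorithm request counters are bumped once, up front, for every predict call, while failure counters are bumped only when trackFailure is true, so permission rejections and invalid model IDs are not counted against the algorithm. A standalone sketch of that counting pattern (hypothetical helper; assumes the plugin's MLStats, StatNames, ActionName, and FunctionName types):

import static org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT;
import static org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT;
import static org.opensearch.ml.stats.StatNames.failureCountStat;
import static org.opensearch.ml.stats.StatNames.requestCountStat;

import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLStats;

class PredictStatsSketch {
    private final MLStats mlStats;

    PredictStatsSketch(MLStats mlStats) {
        this.mlStats = mlStats;
    }

    // Called once per predict request, regardless of outcome.
    void onPredictRequest(FunctionName functionName) {
        mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
        mlStats.createCounterStatIfAbsent(requestCountStat(functionName, ActionName.PREDICT)).increment();
    }

    // Called on failure; trackFailure is false for permission or invalid-modelId rejections.
    void onPredictFailure(FunctionName functionName, boolean trackFailure) {
        if (trackFailure) {
            mlStats.createCounterStatIfAbsent(failureCountStat(functionName, ActionName.PREDICT)).increment();
            mlStats.getStat(ML_TOTAL_FAILURE_COUNT).increment();
        }
    }
}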

plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java

Lines changed: 4 additions & 4 deletions
@@ -56,7 +56,7 @@ public void dispatchTask(ActionListener<DiscoveryNode> listener) {
         // DiscoveryNode[] mlNodes = getEligibleMLNodes();
         DiscoveryNode[] mlNodes = getEligibleDataNodes();
         MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(mlNodes);
-        MLStatsNodesRequest.addAll(ImmutableSet.of(ML_EXECUTING_TASK_COUNT.getName(), JVM_HEAP_USAGE.getName()));
+        MLStatsNodesRequest.addAll(ImmutableSet.of(ML_EXECUTING_TASK_COUNT, JVM_HEAP_USAGE.getName()));
 
         client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> {
             // Check JVM pressure
@@ -78,7 +78,7 @@ public void dispatchTask(ActionListener<DiscoveryNode> listener) {
         // Check # of executing ML task
         candidateNodeResponse = candidateNodeResponse
             .stream()
-            .filter(stat -> (Long) stat.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName()) < maxMLBatchTaskPerNode)
+            .filter(stat -> (Long) stat.getStatsMap().get(ML_EXECUTING_TASK_COUNT) < maxMLBatchTaskPerNode)
             .collect(Collectors.toList());
         if (candidateNodeResponse.size() == 0) {
             String errorMessage = "All nodes' executing ML task count reach limitation.";
@@ -91,8 +91,8 @@ public void dispatchTask(ActionListener<DiscoveryNode> listener) {
         Optional<MLStatsNodeResponse> targetNode = candidateNodeResponse
             .stream()
             .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> {
-                int result = ((Long) r1.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName()))
-                    .compareTo((Long) r2.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName()));
+                int result = ((Long) r1.getStatsMap().get(ML_EXECUTING_TASK_COUNT))
+                    .compareTo((Long) r2.getStatsMap().get(ML_EXECUTING_TASK_COUNT));
                 if (result == 0) {
                     // if multiple nodes have same running task count, choose the one with least
                     // JVM heap usage.
