Skip to content

Commit 2f7e028

Browse files
authored
Enhance profile API to add model centric result controlled by view parameter (#714)
* Enhance profile API to add model centric result controled by view paramter Signed-off-by: Zan Niu <[email protected]> * Enhance profile API to add model centric result controled by view parameter Signed-off-by: Zan Niu <[email protected]> * Enhance profile API to add model centric result controled by view parameter Signed-off-by: Zan Niu <[email protected]> --------- Signed-off-by: Zan Niu <[email protected]>
1 parent 1bf57c5 commit 2f7e028

File tree

6 files changed

+315
-9
lines changed

6 files changed

+315
-9
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
8+
package org.opensearch.ml.action.profile;
9+
10+
import java.io.IOException;
11+
import java.util.HashMap;
12+
import java.util.Map;
13+
14+
import lombok.Getter;
15+
import lombok.NoArgsConstructor;
16+
import lombok.Setter;
17+
18+
import org.opensearch.common.io.stream.StreamInput;
19+
import org.opensearch.common.io.stream.StreamOutput;
20+
import org.opensearch.common.io.stream.Writeable;
21+
import org.opensearch.common.xcontent.ToXContentFragment;
22+
import org.opensearch.common.xcontent.XContentBuilder;
23+
import org.opensearch.ml.common.MLTask;
24+
import org.opensearch.ml.profile.MLModelProfile;
25+
26+
@Getter
27+
@NoArgsConstructor
28+
public class MLProfileModelResponse implements ToXContentFragment, Writeable {
29+
@Setter
30+
private String[] targetWorkerNodes;
31+
32+
@Setter
33+
private String[] workerNodes;
34+
35+
private Map<String, MLModelProfile> mlModelProfileMap = new HashMap<>();
36+
37+
private Map<String, MLTask> mlTaskMap = new HashMap<>();
38+
39+
public MLProfileModelResponse(String[] targetWorkerNodes, String[] workerNodes) {
40+
this.targetWorkerNodes = targetWorkerNodes;
41+
this.workerNodes = workerNodes;
42+
}
43+
44+
public MLProfileModelResponse(StreamInput in) throws IOException {
45+
this.workerNodes = in.readOptionalStringArray();
46+
this.targetWorkerNodes = in.readOptionalStringArray();
47+
if (in.readBoolean()) {
48+
this.mlModelProfileMap = in.readMap(StreamInput::readString, MLModelProfile::new);
49+
}
50+
if (in.readBoolean()) {
51+
this.mlTaskMap = in.readMap(StreamInput::readString, MLTask::new);
52+
}
53+
}
54+
55+
@Override
56+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
57+
builder.startObject();
58+
if (targetWorkerNodes != null) {
59+
builder.field("target_worker_nodes", targetWorkerNodes);
60+
}
61+
if (workerNodes != null) {
62+
builder.field("worker_nodes", workerNodes);
63+
}
64+
if (mlModelProfileMap.size() > 0) {
65+
builder.startObject("nodes");
66+
for (Map.Entry<String, MLModelProfile> entry : mlModelProfileMap.entrySet()) {
67+
builder.field(entry.getKey(), entry.getValue());
68+
}
69+
builder.endObject();
70+
}
71+
if (mlTaskMap.size() > 0) {
72+
builder.startObject("tasks");
73+
for (Map.Entry<String, MLTask> entry : mlTaskMap.entrySet()) {
74+
builder.field(entry.getKey(), entry.getValue());
75+
}
76+
builder.endObject();
77+
}
78+
builder.endObject();
79+
return builder;
80+
}
81+
82+
@Override
83+
public void writeTo(StreamOutput streamOutput) throws IOException {
84+
streamOutput.writeOptionalStringArray(workerNodes);
85+
streamOutput.writeOptionalStringArray(targetWorkerNodes);
86+
if (mlModelProfileMap.size() > 0) {
87+
streamOutput.writeBoolean(true);
88+
streamOutput.writeMap(mlModelProfileMap, StreamOutput::writeString, (o, r) -> r.writeTo(o));
89+
} else {
90+
streamOutput.writeBoolean(false);
91+
}
92+
if (mlTaskMap.size() > 0) {
93+
streamOutput.writeBoolean(true);
94+
streamOutput.writeMap(mlTaskMap, StreamOutput::writeString, (o, r) -> r.writeTo(o));
95+
} else {
96+
streamOutput.writeBoolean(false);
97+
}
98+
99+
}
100+
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLProfileAction.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import java.io.IOException;
1717
import java.util.Arrays;
18+
import java.util.HashMap;
1819
import java.util.List;
20+
import java.util.Map;
1921
import java.util.Optional;
2022
import java.util.stream.Collectors;
2123

@@ -28,20 +30,29 @@
2830
import org.opensearch.common.xcontent.XContentBuilder;
2931
import org.opensearch.common.xcontent.XContentParser;
3032
import org.opensearch.ml.action.profile.MLProfileAction;
33+
import org.opensearch.ml.action.profile.MLProfileModelResponse;
3134
import org.opensearch.ml.action.profile.MLProfileNodeResponse;
3235
import org.opensearch.ml.action.profile.MLProfileRequest;
36+
import org.opensearch.ml.common.MLTask;
37+
import org.opensearch.ml.profile.MLModelProfile;
3338
import org.opensearch.ml.profile.MLProfileInput;
39+
import org.opensearch.ml.utils.RestActionUtils;
3440
import org.opensearch.rest.BaseRestHandler;
3541
import org.opensearch.rest.BytesRestResponse;
3642
import org.opensearch.rest.RestRequest;
3743
import org.opensearch.rest.RestStatus;
3844

3945
import com.google.common.collect.ImmutableList;
46+
import com.google.common.collect.ImmutableMap;
4047

4148
@Log4j2
4249
public class RestMLProfileAction extends BaseRestHandler {
4350
private static final String PROFILE_ML_ACTION = "profile_ml";
4451

52+
private static final String VIEW = "view";
53+
private static final String MODEL_VIEW = "model";
54+
private static final String NODE_VIEW = "node";
55+
4556
private ClusterService clusterService;
4657

4758
/**
@@ -80,6 +91,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
8091
} else {
8192
mlProfileInput = createMLProfileInputFromRequestParams(request);
8293
}
94+
String view = RestActionUtils.getStringParam(request, VIEW).orElse(NODE_VIEW);
8395
String[] nodeIds = mlProfileInput.retrieveProfileOnAllNodes()
8496
? getAllNodes(clusterService)
8597
: mlProfileInput.getNodeIds().toArray(new String[0]);
@@ -93,7 +105,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
93105
List<MLProfileNodeResponse> nodeProfiles = r.getNodes().stream().filter(s -> !s.isEmpty()).collect(Collectors.toList());
94106
log.debug("Build MLProfileNodeResponse for size of {}", nodeProfiles.size());
95107
if (nodeProfiles.size() > 0) {
96-
r.toXContent(builder, ToXContent.EMPTY_PARAMS);
108+
if (NODE_VIEW.equals(view)) {
109+
r.toXContent(builder, ToXContent.EMPTY_PARAMS);
110+
} else if (MODEL_VIEW.equals(view)) {
111+
Map<String, MLProfileModelResponse> modelCentricProfileMap = buildModelCentricResult(nodeProfiles);
112+
builder.startObject("models");
113+
for (Map.Entry<String, MLProfileModelResponse> entry : modelCentricProfileMap.entrySet()) {
114+
builder.field(entry.getKey(), entry.getValue());
115+
}
116+
builder.endObject();
117+
}
97118
}
98119
builder.endObject();
99120
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
@@ -105,6 +126,59 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
105126
};
106127
}
107128

129+
/**
130+
* The data structure for node centric is:
131+
* MLProfileNodeResponse:
132+
* taskMap: Map<String, MLTask>
133+
* modelMap: Map<String, MLModelProfile> model_id, MLModelProfile
134+
* And we need to convert to format like this:
135+
* modelMap: Map<String, Map<String, MLModelProfile>>
136+
*/
137+
private Map<String, MLProfileModelResponse> buildModelCentricResult(List<MLProfileNodeResponse> nodeResponses) {
138+
// aggregate model information into one final map.
139+
Map<String, MLProfileModelResponse> modelCentricMap = new HashMap<>();
140+
for (MLProfileNodeResponse mlProfileNodeResponse : nodeResponses) {
141+
String nodeId = mlProfileNodeResponse.getNode().getId();
142+
Map<String, MLModelProfile> modelProfileMap = mlProfileNodeResponse.getMlNodeModels();
143+
Map<String, MLTask> taskProfileMap = mlProfileNodeResponse.getMlNodeTasks();
144+
for (Map.Entry<String, MLModelProfile> entry : modelProfileMap.entrySet()) {
145+
MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(entry.getKey());
146+
if (mlProfileModelResponse == null) {
147+
mlProfileModelResponse = new MLProfileModelResponse(
148+
entry.getValue().getTargetWorkerNodes(),
149+
entry.getValue().getWorkerNodes()
150+
);
151+
modelCentricMap.put(entry.getKey(), mlProfileModelResponse);
152+
}
153+
if (mlProfileModelResponse.getTargetWorkerNodes() == null || mlProfileModelResponse.getWorkerNodes() == null) {
154+
mlProfileModelResponse.setTargetWorkerNodes(entry.getValue().getTargetWorkerNodes());
155+
mlProfileModelResponse.setWorkerNodes(entry.getValue().getWorkerNodes());
156+
}
157+
// Create a new object and remove targetWorkerNodes and workerNodes.
158+
MLModelProfile modelProfile = new MLModelProfile(
159+
entry.getValue().getModelState(),
160+
entry.getValue().getPredictor(),
161+
null,
162+
null,
163+
entry.getValue().getModelInferenceStats(),
164+
entry.getValue().getPredictRequestStats()
165+
);
166+
mlProfileModelResponse.getMlModelProfileMap().putAll(ImmutableMap.of(nodeId, modelProfile));
167+
}
168+
169+
for (Map.Entry<String, MLTask> entry : taskProfileMap.entrySet()) {
170+
String modelId = entry.getValue().getModelId();
171+
MLProfileModelResponse mlProfileModelResponse = modelCentricMap.get(modelId);
172+
if (mlProfileModelResponse == null) {
173+
mlProfileModelResponse = new MLProfileModelResponse();
174+
modelCentricMap.put(modelId, mlProfileModelResponse);
175+
}
176+
mlProfileModelResponse.getMlTaskMap().putAll(ImmutableMap.of(entry.getKey(), entry.getValue()));
177+
}
178+
}
179+
return modelCentricMap;
180+
}
181+
108182
MLProfileInput createMLProfileInputFromRequestParams(RestRequest request) {
109183
MLProfileInput mlProfileInput = new MLProfileInput();
110184
Optional<String[]> modelIds = splitCommaSeparatedParam(request, PARAMETER_MODEL_ID);

plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,8 @@ public static Optional<String[]> splitCommaSeparatedParam(RestRequest request, S
161161
return Optional.ofNullable(request.param(paramName)).map(s -> s.split(","));
162162
}
163163

164+
public static Optional<String> getStringParam(RestRequest request, String paramName) {
165+
return Optional.ofNullable(request.param(paramName));
166+
}
167+
164168
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
*
3+
* * Copyright OpenSearch Contributors
4+
* * SPDX-License-Identifier: Apache-2.0
5+
*
6+
*/
7+
8+
package org.opensearch.ml.action.profile;
9+
10+
import java.io.IOException;
11+
import java.time.Instant;
12+
import java.util.Arrays;
13+
import java.util.HashMap;
14+
import java.util.Map;
15+
16+
import org.junit.Before;
17+
import org.opensearch.common.io.stream.BytesStreamOutput;
18+
import org.opensearch.common.xcontent.ToXContent;
19+
import org.opensearch.common.xcontent.XContentBuilder;
20+
import org.opensearch.common.xcontent.XContentType;
21+
import org.opensearch.commons.authuser.User;
22+
import org.opensearch.ml.common.FunctionName;
23+
import org.opensearch.ml.common.MLTask;
24+
import org.opensearch.ml.common.MLTaskState;
25+
import org.opensearch.ml.common.MLTaskType;
26+
import org.opensearch.ml.common.dataset.MLInputDataType;
27+
import org.opensearch.ml.common.model.MLModelState;
28+
import org.opensearch.ml.profile.MLModelProfile;
29+
import org.opensearch.ml.profile.MLPredictRequestStats;
30+
import org.opensearch.ml.utils.TestHelper;
31+
import org.opensearch.test.OpenSearchTestCase;
32+
33+
public class MLProfileModelResponseTests extends OpenSearchTestCase {
34+
35+
MLTask mlTask;
36+
MLModelProfile mlModelProfile;
37+
38+
@Before
39+
public void setup() {
40+
mlTask = MLTask
41+
.builder()
42+
.taskId("test_id")
43+
.modelId("model_id")
44+
.taskType(MLTaskType.TRAINING)
45+
.functionName(FunctionName.AD_LIBSVM)
46+
.state(MLTaskState.CREATED)
47+
.inputType(MLInputDataType.DATA_FRAME)
48+
.progress(0.4f)
49+
.outputIndex("test_index")
50+
.workerNodes(Arrays.asList("test_node"))
51+
.createTime(Instant.ofEpochMilli(123))
52+
.lastUpdateTime(Instant.ofEpochMilli(123))
53+
.error("error")
54+
.user(new User())
55+
.async(false)
56+
.build();
57+
mlModelProfile = MLModelProfile
58+
.builder()
59+
.predictor("test_predictor")
60+
.workerNodes(new String[] { "node1", "node2" })
61+
.modelState(MLModelState.LOADED)
62+
.modelInferenceStats(MLPredictRequestStats.builder().count(10L).average(11.0).max(20.0).min(5.0).build())
63+
.build();
64+
}
65+
66+
public void test_create_MLProfileModelResponse_withArgs() throws IOException {
67+
String[] targetWorkerNodes = new String[] { "node1", "node2" };
68+
String[] workerNodes = new String[] { "node1" };
69+
Map<String, MLModelProfile> profileMap = new HashMap<>();
70+
Map<String, MLTask> taskMap = new HashMap<>();
71+
profileMap.put("node1", mlModelProfile);
72+
taskMap.put("node1", mlTask);
73+
MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes);
74+
response.getMlModelProfileMap().putAll(profileMap);
75+
response.getMlTaskMap().putAll(taskMap);
76+
BytesStreamOutput output = new BytesStreamOutput();
77+
response.writeTo(output);
78+
MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput());
79+
assertNotNull(newResponse.getTargetWorkerNodes());
80+
assertNotNull(response.getTargetWorkerNodes());
81+
assertEquals(newResponse.getTargetWorkerNodes().length, response.getTargetWorkerNodes().length);
82+
assertEquals(newResponse.getMlModelProfileMap().size(), response.getMlModelProfileMap().size());
83+
assertEquals(newResponse.getMlTaskMap().size(), response.getMlTaskMap().size());
84+
}
85+
86+
public void test_create_MLProfileModelResponse_NoArgs() throws IOException {
87+
MLProfileModelResponse response = new MLProfileModelResponse();
88+
BytesStreamOutput output = new BytesStreamOutput();
89+
response.writeTo(output);
90+
MLProfileModelResponse newResponse = new MLProfileModelResponse(output.bytes().streamInput());
91+
assertNull(response.getWorkerNodes());
92+
assertNull(newResponse.getWorkerNodes());
93+
}
94+
95+
public void test_toXContent() throws IOException {
96+
String[] targetWorkerNodes = new String[] { "node1", "node2" };
97+
String[] workerNodes = new String[] { "node1" };
98+
Map<String, MLModelProfile> profileMap = new HashMap<>();
99+
Map<String, MLTask> taskMap = new HashMap<>();
100+
profileMap.put("node1", mlModelProfile);
101+
taskMap.put("node1", mlTask);
102+
MLProfileModelResponse response = new MLProfileModelResponse(targetWorkerNodes, workerNodes);
103+
response.getMlModelProfileMap().putAll(profileMap);
104+
response.getMlTaskMap().putAll(taskMap);
105+
106+
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
107+
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
108+
String xContentString = TestHelper.xContentBuilderToString(builder);
109+
System.out.println(xContentString);
110+
}
111+
112+
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import static org.mockito.Mockito.times;
1313
import static org.mockito.Mockito.verify;
1414
import static org.mockito.Mockito.when;
15-
import static org.opensearch.ml.utils.TestHelper.getProfileRestRequest;
16-
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;
15+
import static org.opensearch.ml.utils.TestHelper.*;
1716

1817
import java.io.IOException;
1918
import java.time.Instant;
@@ -68,6 +67,7 @@
6867
import org.opensearch.threadpool.ThreadPool;
6968

7069
import com.google.common.collect.ImmutableList;
70+
import com.google.common.collect.ImmutableMap;
7171

7272
public class RestMLProfileActionTests extends OpenSearchTestCase {
7373
@Rule
@@ -286,6 +286,14 @@ public void test_PrepareRequest_Failure() throws Exception {
286286
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
287287
}
288288

289+
public void test_WhenViewIsModel_ReturnModelViewResult() throws Exception {
290+
MLProfileInput mlProfileInput = new MLProfileInput();
291+
RestRequest request = getProfileRestRequestWithQueryParams(mlProfileInput, ImmutableMap.of("view", "model"));
292+
profileAction.handleRequest(request, channel, client);
293+
ArgumentCaptor<MLProfileRequest> argumentCaptor = ArgumentCaptor.forClass(MLProfileRequest.class);
294+
verify(client, times(1)).execute(eq(MLProfileAction.INSTANCE), argumentCaptor.capture(), any());
295+
}
296+
289297
private RestRequest getRestRequest() {
290298
Map<String, String> params = new HashMap<>();
291299
params.put("task_id", "test_id");

0 commit comments

Comments
 (0)