Skip to content

Commit 46a8d6a

Browse files
authored
add planning work nodes to model (#715)
* add planning work nodes to model Signed-off-by: Yaliang Wu <[email protected]> * add test Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]>
1 parent 2f7e028 commit 46a8d6a

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ public class CommonValue {
9191
+ MLModel.CURRENT_WORKER_NODE_COUNT_FIELD
9292
+ "\" : {\"type\": \"integer\"},\n"
9393
+ " \""
94+
+ MLModel.PLANNING_WORKER_NODES_FIELD
95+
+ "\": {\"type\": \"keyword\"},\n"
96+
+ " \""
9497
+ MLModel.MODEL_CONFIG_FIELD
9598
+ "\" : {\"properties\":{\""
9699
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""

common/src/main/java/org/opensearch/ml/common/MLModel.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import java.io.IOException;
2323
import java.time.Instant;
24+
import java.util.ArrayList;
25+
import java.util.List;
2426

2527
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2628
import static org.opensearch.ml.common.CommonValue.USER;
@@ -54,6 +56,7 @@ public class MLModel implements ToXContentObject {
5456
public static final String TOTAL_CHUNKS_FIELD = "total_chunks";
5557
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
5658
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
59+
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
5760

5861
private String name;
5962
private FunctionName algorithm;
@@ -81,6 +84,7 @@ public class MLModel implements ToXContentObject {
8184
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
8285
private Integer currentWorkerNodeCount; // model is deployed to how many nodes
8386

87+
private String[] planningWorkerNodes; // plan to deploy model to these nodes
8488
@Builder(toBuilder = true)
8589
public MLModel(String name,
8690
FunctionName algorithm,
@@ -101,7 +105,8 @@ public MLModel(String name,
101105
String modelId, Integer chunkNumber,
102106
Integer totalChunks,
103107
Integer planningWorkerNodeCount,
104-
Integer currentWorkerNodeCount) {
108+
Integer currentWorkerNodeCount,
109+
String[] planningWorkerNodes) {
105110
this.name = name;
106111
this.algorithm = algorithm;
107112
this.version = version;
@@ -123,6 +128,7 @@ public MLModel(String name,
123128
this.totalChunks = totalChunks;
124129
this.planningWorkerNodeCount = planningWorkerNodeCount;
125130
this.currentWorkerNodeCount = currentWorkerNodeCount;
131+
this.planningWorkerNodes = planningWorkerNodes;
126132
}
127133

128134
public MLModel(StreamInput input) throws IOException{
@@ -158,6 +164,7 @@ public MLModel(StreamInput input) throws IOException{
158164
totalChunks = input.readOptionalInt();
159165
planningWorkerNodeCount = input.readOptionalInt();
160166
currentWorkerNodeCount = input.readOptionalInt();
167+
planningWorkerNodes = input.readOptionalStringArray();
161168
}
162169
}
163170

@@ -203,6 +210,7 @@ public void writeTo(StreamOutput out) throws IOException {
203210
out.writeOptionalInt(totalChunks);
204211
out.writeOptionalInt(planningWorkerNodeCount);
205212
out.writeOptionalInt(currentWorkerNodeCount);
213+
out.writeOptionalStringArray(planningWorkerNodes);
206214
}
207215

208216
@Override
@@ -271,6 +279,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
271279
if (currentWorkerNodeCount != null) {
272280
builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount);
273281
}
282+
if (planningWorkerNodes != null && planningWorkerNodes.length > 0) {
283+
builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
284+
}
274285
builder.endObject();
275286
return builder;
276287
}
@@ -300,6 +311,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
300311
Integer totalChunks = null;
301312
Integer planningWorkerNodeCount = null;
302313
Integer currentWorkerNodeCount = null;
314+
List<String> planningWorkerNodes = new ArrayList<>();
303315

304316
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
305317
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -361,6 +373,12 @@ public static MLModel parse(XContentParser parser) throws IOException {
361373
case CURRENT_WORKER_NODE_COUNT_FIELD:
362374
currentWorkerNodeCount = parser.intValue();
363375
break;
376+
case PLANNING_WORKER_NODES_FIELD:
377+
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
378+
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
379+
planningWorkerNodes.add(parser.text());
380+
}
381+
break;
364382
case CREATED_TIME_FIELD:
365383
createdTime = Instant.ofEpochMilli(parser.longValue());
366384
break;
@@ -403,6 +421,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
403421
.totalChunks(totalChunks)
404422
.planningWorkerNodeCount(planningWorkerNodeCount)
405423
.currentWorkerNodeCount(currentWorkerNodeCount)
424+
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
406425
.build();
407426
}
408427

plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,19 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
254254
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOAD_FAILED));
255255
});
256256

257+
List<String> workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList());
257258
mlModelManager
258259
.updateModel(
259260
modelId,
260261
ImmutableMap
261-
.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING, MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size()),
262+
.of(
263+
MLModel.MODEL_STATE_FIELD,
264+
MLModelState.LOADING,
265+
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
266+
eligibleNodes.size(),
267+
MLModel.PLANNING_WORKER_NODES_FIELD,
268+
workerNodes
269+
),
262270
ActionListener
263271
.wrap(
264272
r -> client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener),

plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010

1111
import java.lang.reflect.Field;
1212
import java.nio.file.Path;
13+
import java.util.Arrays;
1314
import java.util.List;
15+
import java.util.Map;
1416
import java.util.concurrent.ExecutorService;
1517

1618
import org.junit.Before;
19+
import org.mockito.ArgumentCaptor;
1720
import org.mockito.InjectMocks;
1821
import org.mockito.Mock;
1922
import org.mockito.Mockito;
@@ -231,17 +234,25 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No
231234

232235
when(mlTaskManager.contains(anyString())).thenReturn(true);
233236

237+
DiscoveryNode discoveryNode = mock(DiscoveryNode.class);
238+
when(discoveryNode.getId()).thenReturn("node1");
234239
transportLoadModelAction
235240
.updateModelLoadStatusAndTriggerOnNodesAction(
236241
modelId,
237242
"mock_task_id",
238243
mlModel,
239244
localNodeId,
240245
mlTask,
241-
eligibleNodes,
246+
Arrays.asList(discoveryNode),
242247
FunctionName.ANOMALY_LOCALIZATION
243248
);
244249
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
250+
251+
ArgumentCaptor<Map<String, Object>> captor = ArgumentCaptor.forClass(Map.class);
252+
verify(mlModelManager).updateModel(anyString(), captor.capture(), any());
253+
Map<String, Object> map = captor.getValue();
254+
assertNotNull(map.get(MLModel.PLANNING_WORKER_NODES_FIELD));
255+
assertEquals(1, (((List) map.get(MLModel.PLANNING_WORKER_NODES_FIELD)).size()));
245256
}
246257

247258
public void testUpdateModelLoadStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() {

0 commit comments

Comments
 (0)