Skip to content

Commit b11c15b

Browse files
authored
[ML] Adding new trained model allocation service (#75778)
Adds a new service for allocating trained models to nodes. Initially, this only supports PyTorch models and simply allocates to nodes with the ML role. The design is fairly simple:
- A master node service allows new allocations to be created, updated, and deleted in cluster state.
- A node service listens for updates that reference the local node, plus any models it may already have allocated, and updates accordingly.
This type of service splits the difference between the logic of shard allocation and persistent tasks; neither fully addressed the need here.
1 parent 10a1d27 commit b11c15b

File tree

37 files changed

+3964
-314
lines changed

37 files changed

+3964
-314
lines changed

.idea/inspectionProfiles/Project_Default.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.action.ActionRequestValidationException;
11+
import org.elasticsearch.action.ActionResponse;
12+
import org.elasticsearch.action.ActionType;
13+
import org.elasticsearch.action.support.master.MasterNodeRequest;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
17+
import org.elasticsearch.common.xcontent.ParseField;
18+
import org.elasticsearch.common.xcontent.ToXContentObject;
19+
import org.elasticsearch.common.xcontent.XContentBuilder;
20+
import org.elasticsearch.common.xcontent.XContentParser;
21+
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
22+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
23+
24+
import java.io.IOException;
25+
import java.util.Objects;
26+
27+
public class CreateTrainedModelAllocationAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
28+
public static final CreateTrainedModelAllocationAction INSTANCE = new CreateTrainedModelAllocationAction();
29+
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/create";
30+
31+
private CreateTrainedModelAllocationAction() {
32+
super(NAME, CreateTrainedModelAllocationAction.Response::new);
33+
}
34+
35+
public static class Request extends MasterNodeRequest<Request> {
36+
private final StartTrainedModelDeploymentAction.TaskParams taskParams;
37+
38+
public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) {
39+
this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams");
40+
}
41+
42+
public Request(StreamInput in) throws IOException {
43+
super(in);
44+
this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
45+
}
46+
47+
@Override
48+
public ActionRequestValidationException validate() {
49+
return null;
50+
}
51+
52+
@Override
53+
public void writeTo(StreamOutput out) throws IOException {
54+
super.writeTo(out);
55+
taskParams.writeTo(out);
56+
}
57+
58+
@Override
59+
public boolean equals(Object o) {
60+
if (this == o) return true;
61+
if (o == null || getClass() != o.getClass()) return false;
62+
Request request = (Request) o;
63+
return Objects.equals(taskParams, request.taskParams);
64+
}
65+
66+
@Override
67+
public int hashCode() {
68+
return Objects.hash(taskParams);
69+
}
70+
71+
public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
72+
return taskParams;
73+
}
74+
}
75+
76+
public static class Response extends ActionResponse implements ToXContentObject {
77+
78+
private static final ParseField ALLOCATION = new ParseField("allocation");
79+
80+
private static final ConstructingObjectParser<Response, Void> PARSER = new ConstructingObjectParser<>(
81+
"create_trained_model_allocation_response",
82+
a -> new Response((TrainedModelAllocation) a[0])
83+
);
84+
static {
85+
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> TrainedModelAllocation.fromXContent(p), ALLOCATION);
86+
}
87+
static Response fromXContent(XContentParser parser) {
88+
return PARSER.apply(parser, null);
89+
}
90+
91+
private final TrainedModelAllocation trainedModelAllocation;
92+
93+
public Response(TrainedModelAllocation trainedModelAllocation) {
94+
this.trainedModelAllocation = trainedModelAllocation;
95+
}
96+
97+
public Response(StreamInput in) throws IOException {
98+
super(in);
99+
this.trainedModelAllocation = new TrainedModelAllocation(in);
100+
}
101+
102+
@Override
103+
public void writeTo(StreamOutput out) throws IOException {
104+
trainedModelAllocation.writeTo(out);
105+
}
106+
107+
public TrainedModelAllocation getTrainedModelAllocation() {
108+
return trainedModelAllocation;
109+
}
110+
111+
@Override
112+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
113+
builder.startObject();
114+
builder.field("allocation", trainedModelAllocation);
115+
builder.endObject();
116+
return builder;
117+
}
118+
119+
@Override
120+
public boolean equals(Object o) {
121+
if (this == o) return true;
122+
if (o == null || getClass() != o.getClass()) return false;
123+
Response response = (Response) o;
124+
return Objects.equals(trainedModelAllocation, response.trainedModelAllocation);
125+
}
126+
127+
@Override
128+
public int hashCode() {
129+
return Objects.hash(trainedModelAllocation);
130+
}
131+
}
132+
133+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.action.ActionRequestValidationException;
11+
import org.elasticsearch.action.ActionType;
12+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
13+
import org.elasticsearch.action.support.master.MasterNodeRequest;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
18+
import java.io.IOException;
19+
import java.util.Objects;
20+
21+
public class DeleteTrainedModelAllocationAction extends ActionType<AcknowledgedResponse> {
22+
public static final DeleteTrainedModelAllocationAction INSTANCE = new DeleteTrainedModelAllocationAction();
23+
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/delete";
24+
25+
private DeleteTrainedModelAllocationAction() {
26+
super(NAME, AcknowledgedResponse::readFrom);
27+
}
28+
29+
public static class Request extends MasterNodeRequest<Request> {
30+
private final String modelId;
31+
32+
public Request(String modelId) {
33+
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
34+
}
35+
36+
public Request(StreamInput in) throws IOException {
37+
super(in);
38+
this.modelId = in.readString();
39+
}
40+
41+
public String getModelId() {
42+
return modelId;
43+
}
44+
45+
@Override
46+
public ActionRequestValidationException validate() {
47+
return null;
48+
}
49+
50+
@Override
51+
public void writeTo(StreamOutput out) throws IOException {
52+
super.writeTo(out);
53+
out.writeString(modelId);
54+
}
55+
56+
@Override
57+
public boolean equals(Object o) {
58+
if (this == o) return true;
59+
if (o == null || getClass() != o.getClass()) return false;
60+
Request request = (Request) o;
61+
return Objects.equals(modelId, request.modelId);
62+
}
63+
64+
@Override
65+
public int hashCode() {
66+
return Objects.hash(modelId);
67+
}
68+
}
69+
70+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
import org.elasticsearch.action.ActionRequestValidationException;
1212
import org.elasticsearch.action.ActionType;
1313
import org.elasticsearch.action.support.master.MasterNodeRequest;
14+
import org.elasticsearch.cluster.node.DiscoveryNode;
15+
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
1416
import org.elasticsearch.common.unit.ByteSizeValue;
17+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1518
import org.elasticsearch.common.xcontent.ParseField;
1619
import org.elasticsearch.common.Strings;
1720
import org.elasticsearch.common.io.stream.StreamInput;
1821
import org.elasticsearch.common.io.stream.StreamOutput;
22+
import org.elasticsearch.common.xcontent.XContentParser;
1923
import org.elasticsearch.core.TimeValue;
2024
import org.elasticsearch.common.xcontent.ToXContentObject;
2125
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -31,15 +35,15 @@
3135
import java.util.Objects;
3236
import java.util.concurrent.TimeUnit;
3337

34-
public class StartTrainedModelDeploymentAction extends ActionType<NodeAcknowledgedResponse> {
38+
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
3539

3640
public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
3741
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";
3842

3943
public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS);
4044

4145
public StartTrainedModelDeploymentAction() {
42-
super(NAME, NodeAcknowledgedResponse::new);
46+
super(NAME, CreateTrainedModelAllocationAction.Response::new);
4347
}
4448

4549
public static class Request extends MasterNodeRequest<Request> implements ToXContentObject {
@@ -120,9 +124,29 @@ public String toString() {
120124

121125
public static class TaskParams implements PersistentTaskParams, MlTaskParams {
122126

123-
public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
127+
// TODO add support for other roles? If so, it may have to be an instance method...
128+
// NOTE, whatever determines allocation should not be dynamically set on the node
129+
// Otherwise allocation logic might fail
130+
public static boolean mayAllocateToNode(DiscoveryNode node) {
131+
return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
132+
}
124133

134+
public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
125135
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
136+
// Parser for TaskParams. The boolean constructor argument is passed as true
// (presumably the "ignore unknown fields" flag of ConstructingObjectParser - confirm against its API).
// Constructor arguments arrive in declaration order: model id, index, model bytes.
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
    "trained_model_deployment_params",
    true,
    args -> {
        final String modelId = (String) args[0];
        final String index = (String) args[1];
        final Long modelBytes = (Long) args[2];
        return new TaskParams(modelId, index, modelBytes);
    }
);
static {
    // All three fields are required constructor arguments.
    PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
    PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX);
    PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
}
146+
147+
public static TaskParams fromXContent(XContentParser parser) {
148+
return PARSER.apply(parser, null);
149+
}
126150

127151
/**
128152
* This has been found to be approximately 300MB on linux by manual testing.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.action.ActionRequestValidationException;
11+
import org.elasticsearch.action.ActionType;
12+
import org.elasticsearch.action.support.master.AcknowledgedResponse;
13+
import org.elasticsearch.action.support.master.MasterNodeRequest;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
17+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
18+
19+
import java.io.IOException;
20+
import java.util.Objects;
21+
22+
public class UpdateTrainedModelAllocationStateAction extends ActionType<AcknowledgedResponse> {
23+
public static final UpdateTrainedModelAllocationStateAction INSTANCE = new UpdateTrainedModelAllocationStateAction();
24+
public static final String NAME = "cluster:internal/xpack/ml/model_allocation/update";
25+
26+
private UpdateTrainedModelAllocationStateAction() {
27+
super(NAME, AcknowledgedResponse::readFrom);
28+
}
29+
30+
public static class Request extends MasterNodeRequest<Request> {
31+
private final String nodeId;
32+
private final String modelId;
33+
private final RoutingStateAndReason routingState;
34+
35+
public Request(String nodeId, String modelId, RoutingStateAndReason routingState) {
36+
this.nodeId = ExceptionsHelper.requireNonNull(nodeId, "node_id");
37+
this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
38+
this.routingState = ExceptionsHelper.requireNonNull(routingState, "routing_state");
39+
}
40+
41+
public Request(StreamInput in) throws IOException {
42+
super(in);
43+
this.nodeId = in.readString();
44+
this.modelId = in.readString();
45+
this.routingState = new RoutingStateAndReason(in);
46+
}
47+
48+
public String getNodeId() {
49+
return nodeId;
50+
}
51+
52+
public String getModelId() {
53+
return modelId;
54+
}
55+
56+
public RoutingStateAndReason getRoutingState() {
57+
return routingState;
58+
}
59+
60+
@Override
61+
public ActionRequestValidationException validate() {
62+
return null;
63+
}
64+
65+
@Override
66+
public void writeTo(StreamOutput out) throws IOException {
67+
super.writeTo(out);
68+
out.writeString(nodeId);
69+
out.writeString(modelId);
70+
routingState.writeTo(out);
71+
}
72+
73+
@Override
74+
public boolean equals(Object o) {
75+
if (this == o) return true;
76+
if (o == null || getClass() != o.getClass()) return false;
77+
Request request = (Request) o;
78+
return Objects.equals(nodeId, request.nodeId)
79+
&& Objects.equals(modelId, request.modelId)
80+
&& Objects.equals(routingState, request.routingState);
81+
}
82+
83+
@Override
84+
public int hashCode() {
85+
return Objects.hash(nodeId, modelId, routingState);
86+
}
87+
88+
@Override
89+
public String toString() {
90+
return "Request{" + "nodeId='" + nodeId + '\'' + ", modelId='" + modelId + '\'' + ", routingState=" + routingState + '}';
91+
}
92+
}
93+
94+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.allocation;
9+
10+
import java.util.Locale;
11+
12+
/**
 * States of a trained model allocation. The textual form of each constant is its
 * lower-cased name (see {@link #toString()}).
 */
public enum AllocationState {
    STARTED,
    STOPPING;

    /**
     * Resolves a constant from its textual form, ignoring letter case.
     *
     * @param value the constant name in any letter case
     * @return the matching constant
     * @throws IllegalArgumentException if no constant has that name
     * @throws NullPointerException if {@code value} is {@code null}
     */
    public static AllocationState fromString(String value) {
        // Locale.ROOT keeps the case conversion independent of the default locale.
        final String constantName = value.toUpperCase(Locale.ROOT);
        return AllocationState.valueOf(constantName);
    }

    /**
     * @return the constant name in lower case
     */
    @Override
    public String toString() {
        final String constantName = name();
        return constantName.toLowerCase(Locale.ROOT);
    }
}

0 commit comments

Comments
 (0)