Skip to content

Commit 8ba4b14

Browse files
authored
[ML] Fix issues in dynamically reading the number of allocations (#115095)
The GET inference API should dynamically update the num_allocations field with the actual number from the deployed model. This fixes two bugs: one in using the grouped action listener with 0 requests, and a second in reading the wrong field.
1 parent 96dfa1c commit 8ba4b14

File tree

7 files changed

+83
-10
lines changed

7 files changed

+83
-10
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException {
109109
);
110110
}
111111

112+
/**
 * Checks that the {@code num_allocations} value reported for an inference endpoint follows
 * the deployment it is attached to.
 *
 * The test creates a text expansion model, starts a deployment of it, and creates a
 * {@code SPARSE_EMBEDDING} inference endpoint from {@code endpointConfig(deploymentId)}.
 * It asserts the PUT response reports {@code num_allocations} of 1, updates the deployment
 * to 2 allocations, and then asserts that a subsequent GET of the endpoint reports 2.
 *
 * @throws IOException declared by the REST helper calls used here
 */
public void testNumAllocationsIsUpdated() throws IOException {
113+
// The deployment id is deliberately the same string as the model id.
var modelId = "update_num_allocations";
114+
var deploymentId = modelId;
115+
116+
// Create the model and start its deployment before any inference endpoint exists.
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
117+
var response = startMlNodeDeploymemnt(modelId, deploymentId);
118+
assertOkOrCreated(response);
119+
120+
var inferenceId = "test_num_allocations_updated";
121+
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
122+
var serviceSettings = putModel.get("service_settings");
123+
// Baseline: the PUT response is expected to report a single allocation.
assertThat(
124+
putModel.toString(),
125+
serviceSettings,
126+
is(
127+
Map.of(
128+
"num_allocations",
129+
1,
130+
"num_threads",
131+
1,
132+
"model_id",
133+
"update_num_allocations",
134+
"deployment_id",
135+
"update_num_allocations"
136+
)
137+
)
138+
);
139+
140+
// Change the allocation count on the deployment itself, not on the inference endpoint.
assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
141+
142+
// Only num_allocations is expected to differ from the baseline in the GET response.
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
143+
assertThat(
144+
updatedServiceSettings.toString(),
145+
updatedServiceSettings,
146+
is(
147+
Map.of(
148+
"num_allocations",
149+
2,
150+
"num_threads",
151+
1,
152+
"model_id",
153+
"update_num_allocations",
154+
"deployment_id",
155+
"update_num_allocations"
156+
)
157+
)
158+
);
159+
}
160+
112161
private String endpointConfig(String deploymentId) {
113162
return Strings.format("""
114163
{
@@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
147196
return client().performRequest(request);
148197
}
149198

199+
/**
 * Sends a POST to {@code /_ml/trained_models/{deploymentId}/deployment/_update} with a JSON
 * body that sets {@code number_of_allocations} to the given value.
 *
 * @param deploymentId   id inserted into the request path
 * @param numAllocations value written to {@code number_of_allocations} in the request body
 * @return the response returned by {@code client().performRequest}
 * @throws IOException declared for the underlying request call
 */
private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
200+
String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update";
201+
202+
var body = Strings.format("""
203+
{
204+
"number_of_allocations": %d
205+
}
206+
""", numAllocations);
207+
208+
Request request = new Request("POST", endPoint);
209+
request.setJsonEntity(body);
210+
return client().performRequest(request);
211+
}
212+
150213
protected void stopMlNodeDeployment(String deploymentId) throws IOException {
151214
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
152215
Request request = new Request("POST", endpoint);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.stream.Stream;
2525

2626
import static org.hamcrest.Matchers.containsString;
27+
import static org.hamcrest.Matchers.empty;
2728
import static org.hamcrest.Matchers.equalTo;
2829
import static org.hamcrest.Matchers.equalToIgnoringCase;
2930
import static org.hamcrest.Matchers.hasSize;
@@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception {
326327
deleteModel(modelId);
327328
}
328329
}
330+
331+
/**
 * Requests all inference endpoints ({@code "_all"}) of task type {@code RERANK} and asserts
 * that the returned collection is empty.
 *
 * NOTE(review): this presumably relies on no RERANK endpoint existing when the test runs,
 * and looks like the regression test for the zero-models GET path — confirm against the
 * test setup.
 *
 * @throws IOException declared by the {@code getModels} helper
 */
public void testGetZeroModels() throws IOException {
332+
var models = getModels("_all", TaskType.RERANK);
333+
assertThat(models, empty());
334+
}
329335
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceM
126126
}
127127

128128
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
129+
if (unparsedModels.isEmpty()) {
130+
listener.onResponse(new GetInferenceModelAction.Response(List.of()));
131+
return;
132+
}
133+
129134
var parsedModelsByService = new HashMap<String, List<Model>>();
130135
try {
131136
for (var unparsedModel : unparsedModels) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,7 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
9292
}
9393

9494
/**
 * Records a new number of allocations in this model's internal service settings.
 *
 * This diff hunk shows both versions of the body: the lines marked {@code -} rebuilt the
 * settings object, and the line marked {@code +} updates the existing object in place.
 *
 * @param numAllocations the value to store; the accompanying unit test also passes
 *                       {@code null} and expects {@code null} to be read back
 */
public void updateNumAllocation(Integer numAllocations) {
95-
// Removed version: built a replacement settings object from the new allocation count plus
// the existing thread count, model id and adaptive allocations settings.
// NOTE(review): no deployment id is passed to this constructor call — confirm whether the
// old code lost that value.
this.internalServiceSettings = new ElasticsearchInternalServiceSettings(
96-
numAllocations,
97-
this.internalServiceSettings.getNumThreads(),
98-
this.internalServiceSettings.modelId(),
99-
this.internalServiceSettings.getAdaptiveAllocationsSettings()
100-
);
95+
// Added version: changes only the allocation count on the existing settings object.
this.internalServiceSettings.setNumAllocations(numAllocations);
10196
}
10297

10398
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
781781
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
782782
for (var model : models) {
783783
if (model instanceof ElasticsearchInternalModel esModel) {
784-
modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel);
784+
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
785785
} else {
786786
listener.onFailure(
787787
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
3939
public static final String DEPLOYMENT_ID = "deployment_id";
4040
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";
4141

42-
private final Integer numAllocations;
42+
private Integer numAllocations;
4343
private final int numThreads;
4444
private final String modelId;
4545
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
@@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
172172
: null;
173173
}
174174

175+
/**
 * Overwrites the stored number of allocations. The same commit makes the
 * {@code numAllocations} field non-final so that it can be reassigned here.
 *
 * @param numAllocations the new value, assigned as-is with no validation in this method
 */
public void setNumAllocations(Integer numAllocations) {
176+
this.numAllocations = numAllocations;
177+
}
178+
175179
@Override
176180
public void writeTo(StreamOutput out) throws IOException {
177181
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ public void testUpdateNumAllocation() {
2222
);
2323

2424
model.updateNumAllocation(1);
25-
assertEquals(1, model.internalServiceSettings.getNumAllocations().intValue());
25+
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
2626

2727
model.updateNumAllocation(null);
28-
assertNull(model.internalServiceSettings.getNumAllocations());
28+
assertNull(model.getServiceSettings().getNumAllocations());
2929
}
3030
}

0 commit comments

Comments
 (0)