Skip to content

Commit 7647c79

Browse files
committed
[ML] Sync Inference with Trained Model stats
When the Trained Model stats are read, either during `GET _inference` or `PUT _inference`, the Inference stats are updated to reflected the Trained Model stats. Fix #130339
1 parent e07f9fe commit 7647c79

File tree

8 files changed

+57
-18
lines changed

8 files changed

+57
-18
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
114114
}
115115
}).<Boolean>andThen((l2, modelDidPut) -> {
116116
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
117-
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
117+
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2);
118118
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
119119
});
120120
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.core.TimeValue;
1212
import org.elasticsearch.inference.ChunkingSettings;
13-
import org.elasticsearch.inference.Model;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
1615
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -43,7 +42,7 @@ protected String modelNotFoundErrorMessage(String modelId) {
4342

4443
@Override
4544
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
46-
Model model,
45+
ElasticsearchInternalModel esModel,
4746
ActionListener<Boolean> listener
4847
) {
4948
throw new IllegalStateException("cannot start model that uses an existing deployment");

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.ResourceNotFoundException;
1111
import org.elasticsearch.action.ActionListener;
12-
import org.elasticsearch.inference.Model;
1312
import org.elasticsearch.inference.TaskType;
1413
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
1514
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -33,7 +32,7 @@ public ElasticRerankerServiceSettings getServiceSettings() {
3332

3433
@Override
3534
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
36-
Model model,
35+
ElasticsearchInternalModel esModel,
3736
ActionListener<Boolean> listener
3837
) {
3938

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.elasticsearch.rest.RestStatus;
2222
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
2323
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
24+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
25+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
2426
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2527

2628
import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
@@ -85,20 +87,21 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
8587
}
8688

8789
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
88-
Model model,
90+
ElasticsearchInternalModel esModel,
8991
ActionListener<Boolean> listener
9092
) {
9193
return new ActionListener<>() {
9294
@Override
9395
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
96+
esModel.updateServiceSettings(response.getTrainedModelAssignment());
9497
listener.onResponse(Boolean.TRUE);
9598
}
9699

97100
@Override
98101
public void onFailure(Exception e) {
99102
var cause = ExceptionsHelper.unwrapCause(e);
100103
if (cause instanceof ResourceNotFoundException) {
101-
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
104+
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId())));
102105
return;
103106
} else if (cause instanceof ElasticsearchStatusException statusException) {
104107
if (statusException.status() == RestStatus.CONFLICT
@@ -128,8 +131,18 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
128131
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
129132
}
130133

131-
public void updateNumAllocations(Integer numAllocations) {
132-
this.internalServiceSettings.setNumAllocations(numAllocations);
134+
public void updateServiceSettings(AssignmentStats assignmentStats) {
135+
this.internalServiceSettings.setAllocations(
136+
assignmentStats.getNumberOfAllocations(),
137+
assignmentStats.getAdaptiveAllocationsSettings()
138+
);
139+
}
140+
141+
private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) {
142+
this.internalServiceSettings.setAllocations(
143+
this.internalServiceSettings.getNumAllocations(),
144+
trainedModelAssignment.getAdaptiveAllocationsSettings()
145+
);
133146
}
134147

135148
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
890890
ActionListener.wrap(stats -> {
891891
for (var deploymentStats : stats.getStats().results()) {
892892
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
893-
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
893+
modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats));
894894
}
895895
var updatedModels = new ArrayList<Model>();
896896
modelsByDeploymentIds.values().forEach(updatedModels::addAll);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
4343
private Integer numAllocations;
4444
private final int numThreads;
4545
private final String modelId;
46-
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
46+
private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
4747
private final String deploymentId;
4848

4949
public static ElasticsearchInternalServiceSettings fromPersistedMap(Map<String, Object> map) {
@@ -158,8 +158,9 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
158158
this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null;
159159
}
160160

161-
public void setNumAllocations(Integer numAllocations) {
161+
public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
162162
this.numAllocations = numAllocations;
163+
this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
163164
}
164165

165166
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,12 @@
108108
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
109109
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
110110
import static org.hamcrest.Matchers.containsString;
111+
import static org.hamcrest.Matchers.equalTo;
111112
import static org.hamcrest.Matchers.hasSize;
112113
import static org.hamcrest.Matchers.instanceOf;
113114
import static org.hamcrest.Matchers.is;
114115
import static org.mockito.ArgumentMatchers.any;
116+
import static org.mockito.ArgumentMatchers.assertArg;
115117
import static org.mockito.ArgumentMatchers.eq;
116118
import static org.mockito.ArgumentMatchers.same;
117119
import static org.mockito.Mockito.doAnswer;
@@ -1767,7 +1769,9 @@ private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsBy
17671769
modelsByDeploymentId.forEach((deploymentId, models) -> {
17681770
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
17691771
models.forEach(model -> {
1770-
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
1772+
verify((ElasticsearchInternalModel) model).updateServiceSettings(assertArg(assignmentStats -> {
1773+
assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations));
1774+
}));
17711775
verify((ElasticsearchInternalModel) model).mlNodeDeploymentId();
17721776
verifyNoMoreInteractions(model);
17731777
});
@@ -1858,7 +1862,9 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
18581862
var latch = new CountDownLatch(1);
18591863
service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown()));
18601864
assertTrue(latch.await(30, TimeUnit.SECONDS));
1861-
verify(model).updateNumAllocations(3);
1865+
verify(model).updateServiceSettings(
1866+
assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); })
1867+
);
18621868
}
18631869
}
18641870

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10+
import org.elasticsearch.action.ActionListener;
1011
import org.elasticsearch.inference.TaskType;
1112
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
14+
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
15+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
16+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
17+
18+
import static org.hamcrest.Matchers.equalTo;
19+
import static org.mockito.Mockito.mock;
20+
import static org.mockito.Mockito.when;
1221

1322
public class ElserInternalModelTests extends ESTestCase {
1423
public void testUpdateNumAllocation() {
@@ -21,10 +30,22 @@ public void testUpdateNumAllocation() {
2130
null
2231
);
2332

24-
model.updateNumAllocations(1);
25-
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
33+
AssignmentStats assignmentStats = mock();
34+
when(assignmentStats.getNumberOfAllocations()).thenReturn(1);
35+
model.updateServiceSettings(assignmentStats);
36+
37+
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
38+
assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings());
2639

27-
model.updateNumAllocations(null);
28-
assertNull(model.getServiceSettings().getNumAllocations());
40+
TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance();
41+
CreateTrainedModelAssignmentAction.Response response = mock();
42+
when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment);
43+
model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response);
44+
45+
assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
46+
assertThat(
47+
model.getServiceSettings().getAdaptiveAllocationsSettings(),
48+
equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings())
49+
);
2950
}
3051
}

0 commit comments

Comments
 (0)