Skip to content

Commit 2697f85

Browse files
authored
Revert "[ML] Dynamically get of num allocations (#114636)" (#114861)
This reverts commit 8040fbb.
1 parent 4fa8485 commit 2697f85

File tree

7 files changed

+24
-148
lines changed

7 files changed

+24
-148
lines changed

docs/changelog/114636.yaml

Lines changed: 0 additions & 5 deletions
This file was deleted.

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,4 @@ default List<DefaultConfigId> defaultConfigIds() {
210210
default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
211211
defaultsListener.onResponse(List.of());
212212
}
213-
214-
default void updateModelsWithDynamicFields(List<Model> model, ActionListener<List<Model>> listener) {
215-
listener.onResponse(model);
216-
}
217213
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRunnable;
1213
import org.elasticsearch.action.support.ActionFilters;
13-
import org.elasticsearch.action.support.GroupedActionListener;
1414
import org.elasticsearch.action.support.HandledTransportAction;
1515
import org.elasticsearch.common.Strings;
1616
import org.elasticsearch.common.util.concurrent.EsExecutors;
1717
import org.elasticsearch.inference.InferenceServiceRegistry;
18-
import org.elasticsearch.inference.Model;
18+
import org.elasticsearch.inference.ModelConfigurations;
1919
import org.elasticsearch.inference.TaskType;
2020
import org.elasticsearch.inference.UnparsedModel;
2121
import org.elasticsearch.injection.guice.Inject;
@@ -29,11 +29,8 @@
2929
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
3030

3131
import java.util.ArrayList;
32-
import java.util.Comparator;
33-
import java.util.HashMap;
3432
import java.util.List;
3533
import java.util.concurrent.Executor;
36-
import java.util.stream.Collectors;
3734

3835
public class TransportGetInferenceModelAction extends HandledTransportAction<
3936
GetInferenceModelAction.Request,
@@ -99,69 +96,38 @@ private void getSingleModel(
9996

10097
var model = service.get()
10198
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
102-
103-
service.get()
104-
.updateModelsWithDynamicFields(
105-
List.of(model),
106-
delegate.delegateFailureAndWrap(
107-
(l2, updatedModels) -> l2.onResponse(
108-
new GetInferenceModelAction.Response(
109-
updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())
110-
)
111-
)
112-
)
113-
);
99+
delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations())));
114100
}));
115101
}
116102

117103
private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {
118-
modelRegistry.getAllModels(listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))));
104+
modelRegistry.getAllModels(
105+
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
106+
);
119107
}
120108

121109
private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
122110
modelRegistry.getModelsByTaskType(
123111
taskType,
124-
listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))
112+
listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))
125113
);
126114
}
127115

128-
private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
129-
var parsedModelsByService = new HashMap<String, List<Model>>();
130-
try {
131-
for (var unparsedModel : unparsedModels) {
132-
var service = serviceRegistry.getService(unparsedModel.service());
133-
if (service.isEmpty()) {
134-
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
135-
}
136-
var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());
137-
list.add(
138-
service.get()
139-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
140-
);
141-
}
142-
143-
var groupedListener = new GroupedActionListener<List<Model>>(
144-
parsedModelsByService.entrySet().size(),
145-
listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {
146-
var modifiable = new ArrayList<Model>();
147-
for (var l : listOfListOfModels) {
148-
modifiable.addAll(l);
149-
}
150-
modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));
151-
delegate.onResponse(
152-
new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList()))
153-
);
154-
})
155-
);
116+
private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {
117+
var parsedModels = new ArrayList<ModelConfigurations>();
156118

157-
for (var entry : parsedModelsByService.entrySet()) {
158-
serviceRegistry.getService(entry.getKey())
159-
.get() // must be non-null to get this far
160-
.updateModelsWithDynamicFields(entry.getValue(), groupedListener);
119+
for (var unparsedModel : unparsedModels) {
120+
var service = serviceRegistry.getService(unparsedModel.service());
121+
if (service.isEmpty()) {
122+
throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
161123
}
162-
} catch (Exception e) {
163-
listener.onFailure(e);
124+
parsedModels.add(
125+
service.get()
126+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())
127+
.getConfigurations()
128+
);
164129
}
130+
return new GetInferenceModelAction.Response(parsedModels);
165131
}
166132

167133
private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public abstract class ElasticsearchInternalModel extends Model {
2323

24-
protected ElasticsearchInternalServiceSettings internalServiceSettings;
24+
protected final ElasticsearchInternalServiceSettings internalServiceSettings;
2525

2626
public ElasticsearchInternalModel(
2727
String inferenceEntityId,
@@ -91,15 +91,6 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
9191
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
9292
}
9393

94-
public void updateNumAllocation(Integer numAllocations) {
95-
this.internalServiceSettings = new ElasticsearchInternalServiceSettings(
96-
numAllocations,
97-
this.internalServiceSettings.getNumThreads(),
98-
this.internalServiceSettings.modelId(),
99-
this.internalServiceSettings.getAdaptiveAllocationsSettings()
100-
);
101-
}
102-
10394
@Override
10495
public String toString() {
10596
return Strings.toString(this.getConfigurations());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
3333
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3434
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
35-
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
3635
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
3736
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
3837
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
@@ -57,7 +56,6 @@
5756
import java.util.ArrayList;
5857
import java.util.Collections;
5958
import java.util.EnumSet;
60-
import java.util.HashMap;
6159
import java.util.List;
6260
import java.util.Map;
6361
import java.util.Optional;
@@ -788,47 +786,11 @@ public List<DefaultConfigId> defaultConfigIds() {
788786
);
789787
}
790788

789+
/**
790+
* Default configurations that can be used out of the box without creating an endpoint first.
791+
* @param defaultsListener Config listener
792+
*/
791793
@Override
792-
public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {
793-
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
794-
for (var model : models) {
795-
if (model instanceof ElasticsearchInternalModel esModel) {
796-
modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel);
797-
} else {
798-
listener.onFailure(
799-
new ElasticsearchStatusException(
800-
"Cannot update model [{}] as it is not an Elasticsearch service model",
801-
RestStatus.INTERNAL_SERVER_ERROR,
802-
model.getInferenceEntityId()
803-
)
804-
);
805-
return;
806-
}
807-
}
808-
809-
if (modelsByDeploymentIds.isEmpty()) {
810-
listener.onResponse(models);
811-
return;
812-
}
813-
814-
String deploymentIds = String.join(",", modelsByDeploymentIds.keySet());
815-
client.execute(
816-
GetDeploymentStatsAction.INSTANCE,
817-
new GetDeploymentStatsAction.Request(deploymentIds),
818-
ActionListener.wrap(stats -> {
819-
for (var deploymentStats : stats.getStats().results()) {
820-
var model = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
821-
model.updateNumAllocation(deploymentStats.getNumberOfAllocations());
822-
}
823-
listener.onResponse(new ArrayList<>(modelsByDeploymentIds.values()));
824-
}, e -> {
825-
logger.warn("Get deployment stats failed, cannot update the endpoint's number of allocations", e);
826-
// continue with the original response
827-
listener.onResponse(models);
828-
})
829-
);
830-
}
831-
832794
public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
833795
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
834796
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,6 @@ public String modelId() {
194194
return modelId;
195195
}
196196

197-
public String deloymentId() {
198-
return modelId;
199-
}
200-
201197
public Integer getNumAllocations() {
202198
return numAllocations;
203199
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)