| 
9 | 9 | 
 
  | 
10 | 10 | import org.elasticsearch.ElasticsearchStatusException;  | 
11 | 11 | import org.elasticsearch.action.ActionListener;  | 
 | 12 | +import org.elasticsearch.action.ActionRunnable;  | 
12 | 13 | import org.elasticsearch.action.support.ActionFilters;  | 
13 |  | -import org.elasticsearch.action.support.GroupedActionListener;  | 
14 | 14 | import org.elasticsearch.action.support.HandledTransportAction;  | 
15 | 15 | import org.elasticsearch.common.Strings;  | 
16 | 16 | import org.elasticsearch.common.util.concurrent.EsExecutors;  | 
17 | 17 | import org.elasticsearch.inference.InferenceServiceRegistry;  | 
18 |  | -import org.elasticsearch.inference.Model;  | 
 | 18 | +import org.elasticsearch.inference.ModelConfigurations;  | 
19 | 19 | import org.elasticsearch.inference.TaskType;  | 
20 | 20 | import org.elasticsearch.inference.UnparsedModel;  | 
21 | 21 | import org.elasticsearch.injection.guice.Inject;  | 
 | 
29 | 29 | import org.elasticsearch.xpack.inference.registry.ModelRegistry;  | 
30 | 30 | 
 
  | 
31 | 31 | import java.util.ArrayList;  | 
32 |  | -import java.util.Comparator;  | 
33 |  | -import java.util.HashMap;  | 
34 | 32 | import java.util.List;  | 
35 | 33 | import java.util.concurrent.Executor;  | 
36 |  | -import java.util.stream.Collectors;  | 
37 | 34 | 
 
  | 
38 | 35 | public class TransportGetInferenceModelAction extends HandledTransportAction<  | 
39 | 36 |     GetInferenceModelAction.Request,  | 
@@ -99,69 +96,38 @@ private void getSingleModel(  | 
99 | 96 | 
 
  | 
100 | 97 |             var model = service.get()  | 
101 | 98 |                 .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());  | 
102 |  | - | 
103 |  | -            service.get()  | 
104 |  | -                .updateModelsWithDynamicFields(  | 
105 |  | -                    List.of(model),  | 
106 |  | -                    delegate.delegateFailureAndWrap(  | 
107 |  | -                        (l2, updatedModels) -> l2.onResponse(  | 
108 |  | -                            new GetInferenceModelAction.Response(  | 
109 |  | -                                updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())  | 
110 |  | -                            )  | 
111 |  | -                        )  | 
112 |  | -                    )  | 
113 |  | -                );  | 
 | 99 | +            delegate.onResponse(new GetInferenceModelAction.Response(List.of(model.getConfigurations())));  | 
114 | 100 |         }));  | 
115 | 101 |     }  | 
116 | 102 | 
 
  | 
117 | 103 |     private void getAllModels(ActionListener<GetInferenceModelAction.Response> listener) {  | 
118 |  | -        modelRegistry.getAllModels(listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener))));  | 
 | 104 | +        modelRegistry.getAllModels(  | 
 | 105 | +            listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))  | 
 | 106 | +        );  | 
119 | 107 |     }  | 
120 | 108 | 
 
  | 
121 | 109 |     private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {  | 
122 | 110 |         modelRegistry.getModelsByTaskType(  | 
123 | 111 |             taskType,  | 
124 |  | -            listener.delegateFailureAndWrap((l, models) -> executor.execute(() -> parseModels(models, listener)))  | 
 | 112 | +            listener.delegateFailureAndWrap((l, models) -> executor.execute(ActionRunnable.supply(l, () -> parseModels(models))))  | 
125 | 113 |         );  | 
126 | 114 |     }  | 
127 | 115 | 
 
  | 
128 |  | -    private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {  | 
129 |  | -        var parsedModelsByService = new HashMap<String, List<Model>>();  | 
130 |  | -        try {  | 
131 |  | -            for (var unparsedModel : unparsedModels) {  | 
132 |  | -                var service = serviceRegistry.getService(unparsedModel.service());  | 
133 |  | -                if (service.isEmpty()) {  | 
134 |  | -                    throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());  | 
135 |  | -                }  | 
136 |  | -                var list = parsedModelsByService.computeIfAbsent(service.get().name(), s -> new ArrayList<>());  | 
137 |  | -                list.add(  | 
138 |  | -                    service.get()  | 
139 |  | -                        .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())  | 
140 |  | -                );  | 
141 |  | -            }  | 
142 |  | - | 
143 |  | -            var groupedListener = new GroupedActionListener<List<Model>>(  | 
144 |  | -                parsedModelsByService.entrySet().size(),  | 
145 |  | -                listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {  | 
146 |  | -                    var modifiable = new ArrayList<Model>();  | 
147 |  | -                    for (var l : listOfListOfModels) {  | 
148 |  | -                        modifiable.addAll(l);  | 
149 |  | -                    }  | 
150 |  | -                    modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));  | 
151 |  | -                    delegate.onResponse(  | 
152 |  | -                        new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList()))  | 
153 |  | -                    );  | 
154 |  | -                })  | 
155 |  | -            );  | 
 | 116 | +    private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {  | 
 | 117 | +        var parsedModels = new ArrayList<ModelConfigurations>();  | 
156 | 118 | 
 
  | 
157 |  | -            for (var entry : parsedModelsByService.entrySet()) {  | 
158 |  | -                serviceRegistry.getService(entry.getKey())  | 
159 |  | -                    .get()  // must be non-null to get this far  | 
160 |  | -                    .updateModelsWithDynamicFields(entry.getValue(), groupedListener);  | 
 | 119 | +        for (var unparsedModel : unparsedModels) {  | 
 | 120 | +            var service = serviceRegistry.getService(unparsedModel.service());  | 
 | 121 | +            if (service.isEmpty()) {  | 
 | 122 | +                throw serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());  | 
161 | 123 |             }  | 
162 |  | -        } catch (Exception e) {  | 
163 |  | -            listener.onFailure(e);  | 
 | 124 | +            parsedModels.add(  | 
 | 125 | +                service.get()  | 
 | 126 | +                    .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings())  | 
 | 127 | +                    .getConfigurations()  | 
 | 128 | +            );  | 
164 | 129 |         }  | 
 | 130 | +        return new GetInferenceModelAction.Response(parsedModels);  | 
165 | 131 |     }  | 
166 | 132 | 
 
  | 
167 | 133 |     private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {  | 
 | 
0 commit comments