Skip to content

Commit 2072f52

Browse files
author
Max Hniebergall
committed
Put default endpoints behind feature flag
1 parent 239038b commit 2072f52

File tree

6 files changed

+174
-117
lines changed

6 files changed

+174
-117
lines changed

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/distribution/DefaultDistributionDescriptor.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ public boolean isSnapshot() {
3535
}
3636

3737
public Path getDistributionDir() {
38-
return distributionDir.resolve("elasticsearch-" + version + (snapshot ? "-SNAPSHOT" : ""));
38+
return distributionDir.resolve("elasticsearch-" + version);
3939
}
4040

4141
public DistributionType getType() {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.threadpool.ThreadPool;
3232
import org.elasticsearch.xcontent.ToXContentObject;
3333
import org.elasticsearch.xcontent.XContentBuilder;
34+
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
3435
import org.elasticsearch.xpack.inference.InferencePlugin;
3536
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
3637
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -316,8 +317,9 @@ public void testGetAllModels_WithDefaults() throws Exception {
316317
listener.onResponse(defaultConfigs);
317318
return Void.TYPE;
318319
}).when(service).defaultConfigs(any());
319-
320-
defaultIds.forEach(modelRegistry::addDefaultIds);
320+
if (DefaultElserFeatureFlag.isEnabled()) {
321+
defaultIds.forEach(modelRegistry::addDefaultIds);
322+
}
321323

322324
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
323325
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 93 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.xcontent.XContentBuilder;
5050
import org.elasticsearch.xcontent.XContentFactory;
5151
import org.elasticsearch.xpack.core.ClientHelper;
52+
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
5253
import org.elasticsearch.xpack.inference.InferenceIndex;
5354
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
5455
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -117,19 +118,23 @@ public ModelRegistry(Client client) {
117118
* @param defaultConfigIds The defaults
118119
*/
119120
public void addDefaultIds(InferenceService.DefaultConfigId defaultConfigIds) {
120-
var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
121-
if (matched.isPresent()) {
122-
throw new IllegalStateException(
123-
"Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
124-
+ defaultConfigIds.inferenceId()
125-
+ "] declared by service ["
126-
+ defaultConfigIds.service().name()
127-
+ "]. The inference Id is already use by ["
128-
+ matched.get().service().name()
129-
+ "] service."
130-
);
121+
if (DefaultElserFeatureFlag.isEnabled()) {
122+
var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
123+
if (matched.isPresent()) {
124+
throw new IllegalStateException(
125+
"Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
126+
+ defaultConfigIds.inferenceId()
127+
+ "] declared by service ["
128+
+ defaultConfigIds.service().name()
129+
+ "]. The inference Id is already use by ["
130+
+ matched.get().service().name()
131+
+ "] service."
132+
);
133+
}
134+
this.defaultConfigIds.add(defaultConfigIds);
135+
} else {
136+
logger.error("Attempted to addDefaultIds [{}] with the feature flag disabled", defaultConfigIds.inferenceId());
131137
}
132-
this.defaultConfigIds.add(defaultConfigIds);
133138
}
134139

135140
/**
@@ -142,7 +147,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<Unparse
142147
// There should be a hit for the configurations
143148
if (searchResponse.getHits().getHits().length == 0) {
144149
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
145-
if (maybeDefault.isPresent()) {
150+
if (DefaultElserFeatureFlag.isEnabled() && maybeDefault.isPresent()) {
146151
getDefaultConfig(true, maybeDefault.get(), listener);
147152
} else {
148153
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
@@ -173,7 +178,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> lis
173178
// There should be a hit for the configurations
174179
if (searchResponse.getHits().getHits().length == 0) {
175180
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
176-
if (maybeDefault.isPresent()) {
181+
if (DefaultElserFeatureFlag.isEnabled() && maybeDefault.isPresent()) {
177182
getDefaultConfig(true, maybeDefault.get(), listener);
178183
} else {
179184
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
@@ -209,8 +214,12 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt
209214
public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
210215
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
211216
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
212-
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
213-
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
217+
if (DefaultElserFeatureFlag.isEnabled()) {
218+
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
219+
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
220+
} else {
221+
delegate.onResponse(modelConfigs);
222+
}
214223
});
215224

216225
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
@@ -240,7 +249,11 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedM
240249
public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
241250
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
242251
var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
243-
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
252+
if (DefaultElserFeatureFlag.isEnabled()) {
253+
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
254+
} else {
255+
delegate.onResponse(foundConfigs);
256+
}
244257
});
245258

246259
// In theory the index should only contain model config documents
@@ -264,26 +277,32 @@ private void addAllDefaultConfigsIfMissing(
264277
List<InferenceService.DefaultConfigId> matchedDefaults,
265278
ActionListener<List<UnparsedModel>> listener
266279
) {
267-
var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
268-
var missing = matchedDefaults.stream().filter(d -> foundIds.contains(d.inferenceId()) == false).toList();
280+
if (DefaultElserFeatureFlag.isEnabled()) {
269281

270-
if (missing.isEmpty()) {
271-
listener.onResponse(foundConfigs);
272-
} else {
273-
var groupedListener = new GroupedActionListener<UnparsedModel>(
274-
missing.size(),
275-
listener.delegateFailure((delegate, listOfModels) -> {
276-
var allConfigs = new ArrayList<UnparsedModel>();
277-
allConfigs.addAll(foundConfigs);
278-
allConfigs.addAll(listOfModels);
279-
allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
280-
delegate.onResponse(allConfigs);
281-
})
282-
);
282+
var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
283+
var missing = matchedDefaults.stream().filter(d -> foundIds.contains(d.inferenceId()) == false).toList();
284+
285+
if (missing.isEmpty()) {
286+
listener.onResponse(foundConfigs);
287+
} else {
288+
var groupedListener = new GroupedActionListener<UnparsedModel>(
289+
missing.size(),
290+
listener.delegateFailure((delegate, listOfModels) -> {
291+
var allConfigs = new ArrayList<UnparsedModel>();
292+
allConfigs.addAll(foundConfigs);
293+
allConfigs.addAll(listOfModels);
294+
allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
295+
delegate.onResponse(allConfigs);
296+
})
297+
);
283298

284-
for (var required : missing) {
285-
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
299+
for (var required : missing) {
300+
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
301+
}
286302
}
303+
} else {
304+
logger.error("Attempted to add default configs with the feature flag disabled");
305+
assert false;
287306
}
288307
}
289308

@@ -292,40 +311,52 @@ private void getDefaultConfig(
292311
InferenceService.DefaultConfigId defaultConfig,
293312
ActionListener<UnparsedModel> listener
294313
) {
295-
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
296-
boolean foundModel = false;
297-
for (var m : models) {
298-
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
299-
foundModel = true;
300-
if (persistDefaultEndpoints) {
301-
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
302-
} else {
303-
listener.onResponse(modelToUnparsedModel(m));
314+
if (DefaultElserFeatureFlag.isEnabled()) {
315+
316+
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
317+
boolean foundModel = false;
318+
for (var m : models) {
319+
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
320+
foundModel = true;
321+
if (persistDefaultEndpoints) {
322+
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
323+
} else {
324+
listener.onResponse(modelToUnparsedModel(m));
325+
}
326+
break;
304327
}
305-
break;
306328
}
307-
}
308329

309-
if (foundModel == false) {
310-
listener.onFailure(
311-
new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
312-
);
313-
}
314-
}));
330+
if (foundModel == false) {
331+
listener.onFailure(
332+
new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
333+
);
334+
}
335+
}));
336+
} else {
337+
logger.error("Attempted to get default configs with the feature flag disabled");
338+
assert false;
339+
}
315340
}
316341

317342
private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
318-
var responseListener = ActionListener.<Boolean>wrap(success -> {
319-
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
320-
}, exception -> {
321-
if (exception instanceof ResourceAlreadyExistsException) {
322-
logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
323-
} else {
324-
logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
325-
}
326-
});
343+
if (DefaultElserFeatureFlag.isEnabled()) {
344+
345+
var responseListener = ActionListener.<Boolean>wrap(success -> {
346+
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
347+
}, exception -> {
348+
if (exception instanceof ResourceAlreadyExistsException) {
349+
logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
350+
} else {
351+
logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
352+
}
353+
});
327354

328-
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
355+
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
356+
} else {
357+
logger.error("Attempted to store default endpoint with the feature flag disabled");
358+
assert false;
359+
}
329360
}
330361

331362
private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
@@ -673,6 +704,7 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
673704
TaskType taskType,
674705
List<InferenceService.DefaultConfigId> defaultConfigIds
675706
) {
707+
assert DefaultElserFeatureFlag.isEnabled();
676708
return defaultConfigIds.stream()
677709
.filter(defaultConfigId -> defaultConfigId.taskType().equals(taskType))
678710
.collect(Collectors.toList());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
4848
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
4949
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
50+
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
5051
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
5152
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
5253
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -113,7 +114,7 @@ public void parseRequestConfig(
113114
Map<String, Object> config,
114115
ActionListener<Model> modelListener
115116
) {
116-
if (inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
117+
if (DefaultElserFeatureFlag.isEnabled() && inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
117118
modelListener.onFailure(
118119
new ElasticsearchStatusException(
119120
"[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id",
@@ -769,6 +770,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
769770
}
770771

771772
public List<DefaultConfigId> defaultConfigIds() {
773+
assert DefaultElserFeatureFlag.isEnabled();
774+
772775
return List.of(
773776
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
774777
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
@@ -817,13 +820,18 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
817820
}
818821

819822
public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
820-
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
821-
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
822-
defaultsListener.onResponse(defaultConfigsLinuxOptimized());
823-
} else {
824-
defaultsListener.onResponse(defaultConfigsPlatfromAgnostic());
825-
}
826-
}));
823+
if (DefaultElserFeatureFlag.isEnabled()) {
824+
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
825+
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
826+
defaultsListener.onResponse(defaultConfigsLinuxOptimized());
827+
} else {
828+
defaultsListener.onResponse(defaultConfigsPlatfromAgnostic());
829+
}
830+
}));
831+
} else {
832+
logger.error("Attempted to add default configs with the feature flag disabled");
833+
assert false;
834+
}
827835
}
828836

829837
private List<Model> defaultConfigsLinuxOptimized() {
@@ -865,6 +873,7 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
865873

866874
@Override
867875
boolean isDefaultId(String inferenceId) {
876+
assert DefaultElserFeatureFlag.isEnabled();
868877
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
869878
}
870879

0 commit comments

Comments
 (0)