@@ -31,6 +31,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -316,8 +317,9 @@ public void testGetAllModels_WithDefaults() throws Exception {
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());

defaultIds.forEach(modelRegistry::addDefaultIds);
if (DefaultElserFeatureFlag.isEnabled()) {
defaultIds.forEach(modelRegistry::addDefaultIds);
}

AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@@ -49,6 +49,7 @@
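For reference, a minimal sketch of what the DefaultElserFeatureFlag helper used throughout this change presumably looks like, built on Elasticsearch's common FeatureFlag utility. The flag name and property wiring below are assumptions, not taken from this diff.

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.util.FeatureFlag;

// Assumed shape of the gate: a thin wrapper around FeatureFlag, which reads an
// es.<name>_feature_flag_enabled system property and is always on in snapshot builds.
public final class DefaultElserFeatureFlag {

    private DefaultElserFeatureFlag() {}

    // Hypothetical flag name; the real registration may differ.
    private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser");

    public static boolean isEnabled() {
        return FEATURE_FLAG.isEnabled();
    }
}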
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferenceIndex;
import org.elasticsearch.xpack.inference.InferenceSecretsIndex;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -117,19 +118,23 @@ public ModelRegistry(Client client) {
* @param defaultConfigIds The defaults
*/
public void addDefaultIds(InferenceService.DefaultConfigId defaultConfigIds) {
var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
if (matched.isPresent()) {
throw new IllegalStateException(
"Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
+ defaultConfigIds.inferenceId()
+ "] declared by service ["
+ defaultConfigIds.service().name()
+ "]. The inference Id is already use by ["
+ matched.get().service().name()
+ "] service."
);
if (DefaultElserFeatureFlag.isEnabled()) {
var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
if (matched.isPresent()) {
throw new IllegalStateException(
"Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
+ defaultConfigIds.inferenceId()
+ "] declared by service ["
+ defaultConfigIds.service().name()
+ "]. The inference Id is already use by ["
+ matched.get().service().name()
+ "] service."
);
}
this.defaultConfigIds.add(defaultConfigIds);
} else {
logger.error("Attempted to addDefaultIds [{}] with the feature flag disabled", defaultConfigIds.inferenceId());
}
this.defaultConfigIds.add(defaultConfigIds);
}

/**
@@ -142,7 +147,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
// There should be a hit for the configurations
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
if (DefaultElserFeatureFlag.isEnabled() && maybeDefault.isPresent()) {
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
@@ -173,7 +178,7 @@ public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
// There should be a hit for the configurations
if (searchResponse.getHits().getHits().length == 0) {
var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
if (maybeDefault.isPresent()) {
if (DefaultElserFeatureFlag.isEnabled() && maybeDefault.isPresent()) {
getDefaultConfig(true, maybeDefault.get(), listener);
} else {
delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
@@ -209,8 +214,12 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) {
public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
if (DefaultElserFeatureFlag.isEnabled()) {
var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
} else {
delegate.onResponse(modelConfigs);
}
});

QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
@@ -240,7 +249,11 @@ public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
if (DefaultElserFeatureFlag.isEnabled()) {
addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
} else {
delegate.onResponse(foundConfigs);
}
});

// In theory the index should only contain model config documents
@@ -264,26 +277,32 @@ private void addAllDefaultConfigsIfMissing(
List<InferenceService.DefaultConfigId> matchedDefaults,
ActionListener<List<UnparsedModel>> listener
) {
var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
var missing = matchedDefaults.stream().filter(d -> foundIds.contains(d.inferenceId()) == false).toList();
if (DefaultElserFeatureFlag.isEnabled()) {

if (missing.isEmpty()) {
listener.onResponse(foundConfigs);
} else {
var groupedListener = new GroupedActionListener<UnparsedModel>(
missing.size(),
listener.delegateFailure((delegate, listOfModels) -> {
var allConfigs = new ArrayList<UnparsedModel>();
allConfigs.addAll(foundConfigs);
allConfigs.addAll(listOfModels);
allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
delegate.onResponse(allConfigs);
})
);
var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
var missing = matchedDefaults.stream().filter(d -> foundIds.contains(d.inferenceId()) == false).toList();

if (missing.isEmpty()) {
listener.onResponse(foundConfigs);
} else {
var groupedListener = new GroupedActionListener<UnparsedModel>(
missing.size(),
listener.delegateFailure((delegate, listOfModels) -> {
var allConfigs = new ArrayList<UnparsedModel>();
allConfigs.addAll(foundConfigs);
allConfigs.addAll(listOfModels);
allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
delegate.onResponse(allConfigs);
})
);

for (var required : missing) {
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
for (var required : missing) {
getDefaultConfig(persistDefaultEndpoints, required, groupedListener);
}
}
} else {
logger.error("Attempted to add default configs with the feature flag disabled");
assert false;
}
}

@@ -292,40 +311,52 @@ private void getDefaultConfig(
InferenceService.DefaultConfigId defaultConfig,
ActionListener<UnparsedModel> listener
) {
defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
boolean foundModel = false;
for (var m : models) {
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
foundModel = true;
if (persistDefaultEndpoints) {
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
} else {
listener.onResponse(modelToUnparsedModel(m));
if (DefaultElserFeatureFlag.isEnabled()) {

defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
boolean foundModel = false;
for (var m : models) {
if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
foundModel = true;
if (persistDefaultEndpoints) {
storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
} else {
listener.onResponse(modelToUnparsedModel(m));
}
break;
}
break;
}
}

if (foundModel == false) {
listener.onFailure(
new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
);
}
}));
if (foundModel == false) {
listener.onFailure(
new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
);
}
}));
} else {
logger.error("Attempted to get default configs with the feature flag disabled");
assert false;
}
}

private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
var responseListener = ActionListener.<Boolean>wrap(success -> {
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
}, exception -> {
if (exception instanceof ResourceAlreadyExistsException) {
logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
} else {
logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
}
});
if (DefaultElserFeatureFlag.isEnabled()) {

var responseListener = ActionListener.<Boolean>wrap(success -> {
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
}, exception -> {
if (exception instanceof ResourceAlreadyExistsException) {
logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
} else {
logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
}
});

storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
} else {
logger.error("Attempted to store default endpoint with the feature flag disabled");
assert false;
}
}

private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
@@ -673,6 +704,7 @@ static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
TaskType taskType,
List<InferenceService.DefaultConfigId> defaultConfigIds
) {
assert DefaultElserFeatureFlag.isEnabled();
return defaultConfigIds.stream()
.filter(defaultConfigId -> defaultConfigId.taskType().equals(taskType))
.collect(Collectors.toList());
@@ -47,6 +47,7 @@
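The ModelRegistry changes above gate every default-endpoint path behind the flag; branches that should be unreachable with the flag off log an error and trip an assertion, which fails hard in tests (run with -ea) and is a no-op in production. A hypothetical caller-side sketch follows, assuming the plugin wires the registry up roughly like this; none of it is shown in this diff.

// Sketch only: register each service's default config ids with the registry, but only
// when the feature flag is enabled, so the guarded branches above are never reached.
static void registerDefaults(ModelRegistry modelRegistry, List<InferenceService> services) {
    if (DefaultElserFeatureFlag.isEnabled() == false) {
        return; // default endpoints are invisible while the flag is off
    }
    for (InferenceService service : services) {
        service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
    }
}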
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -113,7 +114,7 @@ public void parseRequestConfig(
Map<String, Object> config,
ActionListener<Model> modelListener
) {
if (inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
if (DefaultElserFeatureFlag.isEnabled() && inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
modelListener.onFailure(
new ElasticsearchStatusException(
"[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id",
@@ -769,6 +770,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
}

public List<DefaultConfigId> defaultConfigIds() {
assert DefaultElserFeatureFlag.isEnabled();

return List.of(
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
@@ -817,13 +820,18 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<List<Model>> listener) {
}

public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
defaultsListener.onResponse(defaultConfigsLinuxOptimized());
} else {
defaultsListener.onResponse(defaultConfigsPlatfromAgnostic());
}
}));
if (DefaultElserFeatureFlag.isEnabled()) {
preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
defaultsListener.onResponse(defaultConfigsLinuxOptimized());
} else {
defaultsListener.onResponse(defaultConfigsPlatfromAgnostic());
}
}));
} else {
logger.error("Attempted to add default configs with the feature flag disabled");
assert false;
}
}

private List<Model> defaultConfigsLinuxOptimized() {
@@ -865,6 +873,7 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {

@Override
boolean isDefaultId(String inferenceId) {
assert DefaultElserFeatureFlag.isEnabled();
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
}

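The service-side changes above reserve the default inference ids only when the flag is enabled. A test-style sketch of the expected behaviour follows; the createService() helper, the use of PlainActionFuture, and the exact parseRequestConfig signature are assumptions based on context, not taken from this diff.

public void testParseRequestConfig_RejectsReservedIdWhenFlagEnabled() {
    // Only meaningful when the feature flag is on; otherwise the id is not reserved.
    assumeTrue("requires the default ELSER feature flag", DefaultElserFeatureFlag.isEnabled());

    var service = createService(); // hypothetical helper building the service under test
    var listener = new PlainActionFuture<Model>();
    service.parseRequestConfig(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, new HashMap<>(), listener);

    var e = expectThrows(ElasticsearchStatusException.class, listener::actionGet);
    assertThat(e.getMessage(), containsString("reserved inference Id"));
}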