diff --git a/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/AbstractIndexCompatibilityTestCase.java b/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/AbstractIndexCompatibilityTestCase.java index a271c76bbc96e..d1a44235a3ced 100644 --- a/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/AbstractIndexCompatibilityTestCase.java +++ b/qa/lucene-index-compatibility/src/javaRestTest/java/org/elasticsearch/lucene/AbstractIndexCompatibilityTestCase.java @@ -26,9 +26,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.test.XContentTestUtils; import org.elasticsearch.test.cluster.ElasticsearchCluster; -import org.elasticsearch.test.cluster.local.DefaultSettingsProvider; import org.elasticsearch.test.cluster.local.LocalClusterConfigProvider; -import org.elasticsearch.test.cluster.local.LocalClusterSpec; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.cluster.util.Version; import org.elasticsearch.test.rest.ESRestTestCase; @@ -80,16 +78,6 @@ public abstract class AbstractIndexCompatibilityTestCase extends ESRestTestCase .setting("xpack.security.enabled", "false") .setting("xpack.ml.enabled", "false") .setting("path.repo", () -> REPOSITORY_PATH.getRoot().getPath()) - .settings(new DefaultSettingsProvider() { - @Override - public Map get(LocalClusterSpec.LocalNodeSpec nodeSpec) { - var settings = super.get(nodeSpec); - if (nodeSpec.getVersion().onOrAfter(Version.fromString("9.2.0"))) { - settings.put("xpack.inference.endpoint.cache.enabled", "false"); - } - return settings; - } - }) .apply(() -> clusterConfig) .build(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 1c4543f170c9b..3ade24c89c82e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -50,6 +50,13 @@ public class InferenceFeatures implements FeatureSpecification { public static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTING_FLAT = new NodeFeature("semantic_text.highlighter.flat_index_options"); private static final NodeFeature SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT = new NodeFeature("semantic_text.fields_chunks_format"); + public static final NodeFeature INFERENCE_ENDPOINT_CACHE = new NodeFeature("inference.endpoint.cache"); + + @Override + public Set getFeatures() { + return Set.of(INFERENCE_ENDPOINT_CACHE); + } + @Override public Set getTestFeatures() { var testFeatures = new HashSet<>( @@ -90,6 +97,7 @@ public Set getTestFeatures() { TEXT_SIMILARITY_RERANKER_SNIPPETS ) ); + testFeatures.addAll(getFeatures()); return testFeatures; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e4d66b92d5274..415b93443db5f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -391,7 +391,8 @@ public Collection createComponents(PluginServices services) { settings, modelRegistry.get(), serviceRegistry, - services.projectResolver() + services.projectResolver(), + services.featureService() ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java index 46d93e8d404b7..e9d1705f50d13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/InferenceEndpointRegistry.java @@ -19,8 +19,10 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.features.FeatureService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; +import org.elasticsearch.xpack.inference.InferenceFeatures; import java.util.Collection; import java.util.List; @@ -64,14 +66,17 @@ public static Collection> getSettingsDefinitions() { private final InferenceServiceRegistry serviceRegistry; private final ProjectResolver projectResolver; private final Cache cache; - private volatile boolean cacheEnabled; + private final ClusterService clusterService; + private final FeatureService featureService; + private volatile boolean cacheEnabledViaSetting; public InferenceEndpointRegistry( ClusterService clusterService, Settings settings, ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, - ProjectResolver projectResolver + ProjectResolver projectResolver, + FeatureService featureService ) { this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; @@ -80,15 +85,17 @@ public InferenceEndpointRegistry( .setMaximumWeight(INFERENCE_ENDPOINT_CACHE_WEIGHT.get(settings)) .setExpireAfterWrite(INFERENCE_ENDPOINT_CACHE_EXPIRY.get(settings)) .build(); - this.cacheEnabled = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings); + this.clusterService = clusterService; + this.featureService = featureService; + this.cacheEnabledViaSetting = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings); clusterService.getClusterSettings() - .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabled = enabled); + .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabledViaSetting = enabled); } public void getEndpoint(String inferenceEntityId, ActionListener listener) { var key = new InferenceIdAndProject(inferenceEntityId, projectResolver.getProjectId()); - var cachedModel = cacheEnabled ? cache.get(key) : null; + var cachedModel = cacheEnabled() ? cache.get(key) : null; if (cachedModel != null) { log.trace("Retrieved [{}] from cache.", inferenceEntityId); listener.onResponse(cachedModel); @@ -98,7 +105,7 @@ public void getEndpoint(String inferenceEntityId, ActionListener listener } void invalidateAll(ProjectId projectId) { - if (cacheEnabled) { + if (cacheEnabled()) { var cacheKeys = cache.keys().iterator(); while (cacheKeys.hasNext()) { if (cacheKeys.next().projectId.equals(projectId)) { @@ -126,7 +133,7 @@ private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener