File: AbstractIndexCompatibilityTestCase.java
@@ -26,9 +26,7 @@
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.test.XContentTestUtils;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
-import org.elasticsearch.test.cluster.local.DefaultSettingsProvider;
 import org.elasticsearch.test.cluster.local.LocalClusterConfigProvider;
-import org.elasticsearch.test.cluster.local.LocalClusterSpec;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
 import org.elasticsearch.test.cluster.util.Version;
 import org.elasticsearch.test.rest.ESRestTestCase;
@@ -80,16 +78,6 @@ public abstract class AbstractIndexCompatibilityTestCase extends ESRestTestCase
         .setting("xpack.security.enabled", "false")
         .setting("xpack.ml.enabled", "false")
         .setting("path.repo", () -> REPOSITORY_PATH.getRoot().getPath())
-        .settings(new DefaultSettingsProvider() {
-            @Override
-            public Map<String, String> get(LocalClusterSpec.LocalNodeSpec nodeSpec) {
-                var settings = super.get(nodeSpec);
-                if (nodeSpec.getVersion().onOrAfter(Version.fromString("9.2.0"))) {
-                    settings.put("xpack.inference.endpoint.cache.enabled", "false");
-                }
-                return settings;
-            }
-        })
         .apply(() -> clusterConfig)
         .build();
 
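Note: the version-gated DefaultSettingsProvider removed above is no longer needed because the cache is now feature-gated at runtime (see InferenceFeatures below), so the BWC test cluster does not have to switch it off per node version. A test that still wants the cache off outright can presumably set the dynamic setting directly through the same builder API; a minimal sketch, with the DEFAULT distribution assumed for illustration:

// Hedged sketch: disable the endpoint cache on every node via the setting
// that the removed provider used to apply only to nodes on or after 9.2.0.
ElasticsearchCluster cluster = ElasticsearchCluster.local()
    .distribution(DistributionType.DEFAULT) // assumption, for illustration
    .setting("xpack.inference.endpoint.cache.enabled", "false")
    .build();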

File: InferenceFeatures.java
@@ -50,6 +50,13 @@ public class InferenceFeatures implements FeatureSpecification
     public static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTING_FLAT = new NodeFeature("semantic_text.highlighter.flat_index_options");
     private static final NodeFeature SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT = new NodeFeature("semantic_text.fields_chunks_format");
 
+    public static final NodeFeature INFERENCE_ENDPOINT_CACHE = new NodeFeature("inference.endpoint.cache");
+
+    @Override
+    public Set<NodeFeature> getFeatures() {
+        return Set.of(INFERENCE_ENDPOINT_CACHE);
+    }
+
     @Override
     public Set<NodeFeature> getTestFeatures() {
         var testFeatures = new HashSet<>(
@@ -90,6 +97,7 @@ public Set<NodeFeature> getTestFeatures() {
                 TEXT_SIMILARITY_RERANKER_SNIPPETS
             )
         );
+        testFeatures.addAll(getFeatures());
         return testFeatures;
     }
 }
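Note: publishing INFERENCE_ENDPOINT_CACHE via getFeatures() is what makes it visible to cluster-wide FeatureService checks, and testFeatures.addAll(getFeatures()) keeps the test-feature set a superset of the published features so test gating stays consistent. A hedged sketch of the consumer-side check, assuming injected featureService and clusterService:

// Returns true only once every node in the current cluster state advertises
// the feature — the property that makes the gate safe in rolling upgrades.
boolean allNodesSupportCache = featureService.clusterHasFeature(
    clusterService.state(),
    InferenceFeatures.INFERENCE_ENDPOINT_CACHE
);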

File: inference plugin (createComponents)
@@ -391,7 +391,8 @@ public Collection<?> createComponents(PluginServices services) {
                 settings,
                 modelRegistry.get(),
                 serviceRegistry,
-                services.projectResolver()
+                services.projectResolver(),
+                services.featureService()
             )
         );
 

File: InferenceEndpointRegistry.java
@@ -19,8 +19,10 @@
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.features.FeatureService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.xpack.inference.InferenceFeatures;
 
 import java.util.Collection;
 import java.util.List;
@@ -64,14 +66,17 @@ public static Collection<? extends Setting<?>> getSettingsDefinitions() {
     private final InferenceServiceRegistry serviceRegistry;
     private final ProjectResolver projectResolver;
     private final Cache<InferenceIdAndProject, Model> cache;
-    private volatile boolean cacheEnabled;
+    private final ClusterService clusterService;
+    private final FeatureService featureService;
+    private volatile boolean cacheEnabledViaSetting;
 
     public InferenceEndpointRegistry(
         ClusterService clusterService,
         Settings settings,
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
-        ProjectResolver projectResolver
+        ProjectResolver projectResolver,
+        FeatureService featureService
     ) {
         this.modelRegistry = modelRegistry;
         this.serviceRegistry = serviceRegistry;
@@ -80,15 +85,17 @@ public InferenceEndpointRegistry(
             .setMaximumWeight(INFERENCE_ENDPOINT_CACHE_WEIGHT.get(settings))
             .setExpireAfterWrite(INFERENCE_ENDPOINT_CACHE_EXPIRY.get(settings))
             .build();
-        this.cacheEnabled = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings);
+        this.clusterService = clusterService;
+        this.featureService = featureService;
+        this.cacheEnabledViaSetting = INFERENCE_ENDPOINT_CACHE_ENABLED.get(settings);
 
         clusterService.getClusterSettings()
-            .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabled = enabled);
+            .addSettingsUpdateConsumer(INFERENCE_ENDPOINT_CACHE_ENABLED, enabled -> this.cacheEnabledViaSetting = enabled);
     }
 
     public void getEndpoint(String inferenceEntityId, ActionListener<Model> listener) {
         var key = new InferenceIdAndProject(inferenceEntityId, projectResolver.getProjectId());
-        var cachedModel = cacheEnabled ? cache.get(key) : null;
+        var cachedModel = cacheEnabled() ? cache.get(key) : null;
         if (cachedModel != null) {
             log.trace("Retrieved [{}] from cache.", inferenceEntityId);
             listener.onResponse(cachedModel);
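Note: addSettingsUpdateConsumer only accepts settings registered as dynamic. The definition of INFERENCE_ENDPOINT_CACHE_ENABLED sits outside this diff; a hedged sketch of its likely shape, where the key comes from the test code removed above and the default value is an assumption:

// Hedged sketch: a dynamic boolean setting wired to the volatile field above.
public static final Setting<Boolean> INFERENCE_ENDPOINT_CACHE_ENABLED = Setting.boolSetting(
    "xpack.inference.endpoint.cache.enabled",
    true, // assumed default
    Setting.Property.NodeScope,
    Setting.Property.Dynamic
);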
@@ -98,7 +105,7 @@ public void getEndpoint(String inferenceEntityId, ActionListener<Model> listener
     }
 
     void invalidateAll(ProjectId projectId) {
-        if (cacheEnabled) {
+        if (cacheEnabled()) {
             var cacheKeys = cache.keys().iterator();
             while (cacheKeys.hasNext()) {
                 if (cacheKeys.next().projectId.equals(projectId)) {
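Note: the rest of the loop body is collapsed in this diff. A hedged sketch of the usual pattern with org.elasticsearch.common.cache.Cache, whose keys() iterator supports remove():

// Assumption: the collapsed body removes matching keys through the iterator,
// invalidating every cached endpoint that belongs to the given project.
var cacheKeys = cache.keys().iterator();
while (cacheKeys.hasNext()) {
    if (cacheKeys.next().projectId.equals(projectId)) {
        cacheKeys.remove();
    }
}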
@@ -126,23 +133,28 @@ private void loadFromIndex(InferenceIdAndProject idAndProject, ActionListener<Mo
                 unparsedModel.secrets()
             );
 
-            if (cacheEnabled) {
+            if (cacheEnabled()) {
                 cache.put(idAndProject, model);
             }
             l.onResponse(model);
         }));
     }
 
     public Cache.Stats stats() {
-        return cacheEnabled ? cache.stats() : EMPTY;
+        return cacheEnabled() ? cache.stats() : EMPTY;
     }
 
     public int cacheCount() {
-        return cacheEnabled ? cache.count() : 0;
+        return cacheEnabled() ? cache.count() : 0;
     }
 
     public boolean cacheEnabled() {
-        return cacheEnabled;
+        return cacheEnabledViaSetting && cacheEnabledViaFeature();
     }
+
+    private boolean cacheEnabledViaFeature() {
+        var state = clusterService.state();
+        return state.clusterRecovered() && featureService.clusterHasFeature(state, InferenceFeatures.INFERENCE_ENDPOINT_CACHE);
+    }
 
     private record InferenceIdAndProject(String inferenceEntityId, ProjectId projectId) {}
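Note: the two-part gate means the cache serves entries only when the operator has left the dynamic setting on and the recovered cluster state shows every node supporting inference.endpoint.cache, so a mixed-version cluster mid-upgrade stays cache-off on all nodes. A hedged usage sketch of the registry API shown above (the endpoint id and log lines are illustrative):

// Resolve an endpoint: served from cache when the gate is open, otherwise
// loaded from the inference index and cached for the next lookup.
registry.getEndpoint("my-elser-endpoint", ActionListener.wrap(
    model -> log.info("resolved [{}]", model.getInferenceEntityId()),
    e -> log.warn("failed to resolve endpoint", e)
));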