diff --git a/docs/changelog/123602.yaml b/docs/changelog/123602.yaml new file mode 100644 index 0000000000000..c0d04e640c83c --- /dev/null +++ b/docs/changelog/123602.yaml @@ -0,0 +1,5 @@ +pr: 123602 +summary: Make `CachingUsernamePasswordRealm` and `ReloadablePlugin` Project Aware +area: Infra/Plugins +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index ee21ffdc2d48b..100944c3e055d 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -50,6 +50,7 @@ import org.elasticsearch.cluster.metadata.MetadataDataStreamsService; import org.elasticsearch.cluster.metadata.MetadataIndexTemplateService; import org.elasticsearch.cluster.metadata.MetadataUpdateSettingsService; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.metadata.SystemIndexMetadataUpgradeService; import org.elasticsearch.cluster.metadata.TemplateUpgradeService; @@ -1498,12 +1499,26 @@ private CircuitBreakerService createCircuitBreakerService( * @return A single ReloadablePlugin that, upon reload, reloads the plugins it wraps */ private static ReloadablePlugin wrapPlugins(List reloadablePlugins) { - return settings -> { - for (ReloadablePlugin plugin : reloadablePlugins) { - try { - plugin.reload(settings); - } catch (IOException e) { - throw new UncheckedIOException(e); + return new ReloadablePlugin() { + @Override + public void reload(Settings settings) throws Exception { + for (ReloadablePlugin plugin : reloadablePlugins) { + try { + plugin.reload(settings); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } + + @Override + public void reload(ProjectId projectId, Settings settings) throws Exception { + for (ReloadablePlugin plugin : reloadablePlugins) { + try { + plugin.reload(projectId, settings); + } catch (IOException e) { + throw new UncheckedIOException(e); + } } } }; diff --git a/server/src/main/java/org/elasticsearch/plugins/ReloadablePlugin.java b/server/src/main/java/org/elasticsearch/plugins/ReloadablePlugin.java index a20b6af3d580c..80c9260fa184f 100644 --- a/server/src/main/java/org/elasticsearch/plugins/ReloadablePlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/ReloadablePlugin.java @@ -9,6 +9,7 @@ package org.elasticsearch.plugins; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.settings.Settings; /** @@ -41,4 +42,6 @@ public interface ReloadablePlugin { * if the offending call didn't happen. */ void reload(Settings settings) throws Exception; + + default void reload(ProjectId projectId, Settings settings) throws Exception {} } diff --git a/server/src/test/java/org/elasticsearch/plugins/internal/ReloadAwarePluginTests.java b/server/src/test/java/org/elasticsearch/plugins/internal/ReloadAwarePluginTests.java index 96e45edb8e01e..eb65cc3892569 100644 --- a/server/src/test/java/org/elasticsearch/plugins/internal/ReloadAwarePluginTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/internal/ReloadAwarePluginTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.plugins.internal; +import org.elasticsearch.cluster.metadata.ProjectId; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.node.MockNode; import org.elasticsearch.plugins.Plugin; @@ -57,20 +58,26 @@ public void setReloadCallback(ReloadablePlugin reloadablePlugin) { public void invokeReloadOperation() throws Exception { reloadablePlugin.reload(Settings.EMPTY); + reloadablePlugin.reload(randomUniqueProjectId(), Settings.EMPTY); } } public static class TestReloadablePlugin extends Plugin implements ReloadablePlugin { - private boolean reloaded = false; + private int reloadCount = 0; @Override public void reload(Settings settings) throws Exception { - reloaded = true; + reloadCount++; + } + + @Override + public void reload(ProjectId projectId, Settings settings) throws Exception { + reloadCount++; } public boolean isReloaded() { - return reloaded; + return reloadCount >= 2; } } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java index 287f6035b74d5..8d772b1b0f69b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.CacheLoader; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.util.concurrent.ListenableFuture; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -32,25 +33,62 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm { - private final Cache> cache; + private final UserAuthenticationCache cache; private final ThreadPool threadPool; private final boolean authenticationEnabled; - final Hasher cacheHasher; + protected final Hasher cacheHasher; protected CachingUsernamePasswordRealm(RealmConfig config, ThreadPool threadPool) { + this(config, threadPool, buildDefaultCache(config)); + } + + protected CachingUsernamePasswordRealm(RealmConfig config, ThreadPool threadPool, UserAuthenticationCache cache) { super(config); cacheHasher = Hasher.resolve(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_HASH_ALGO_SETTING)); this.threadPool = threadPool; - final TimeValue ttl = this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING); + this.authenticationEnabled = config.getSetting(CachingUsernamePasswordRealmSettings.AUTHC_ENABLED_SETTING); + this.cache = cache; + } + + private static UserAuthenticationCache buildDefaultCache(RealmConfig config) { + final TimeValue ttl = config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING); if (ttl.getNanos() > 0) { - cache = CacheBuilder.>builder() + final Cache> cache = CacheBuilder.>builder() .setExpireAfterWrite(ttl) - .setMaximumWeight(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING)) + .setMaximumWeight(config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING)) .build(); - } else { - cache = null; + return new UserAuthenticationCache() { + @Override + public void invalidate(String key) { + cache.invalidate(key); + } + + @Override + public void invalidate(String key, ListenableFuture value) { + cache.invalidate(key, value); + } + + @Override + public void invalidateAll() { + cache.invalidateAll(); + } + + @Override + public int count() { + return cache.count(); + } + + @Override + public ListenableFuture computeIfAbsent( + String key, + CacheLoader> loader + ) throws ExecutionException { + return cache.computeIfAbsent(key, loader); + } + }; } - this.authenticationEnabled = config.getSetting(CachingUsernamePasswordRealmSettings.AUTHC_ENABLED_SETTING); + + return null; } @Override @@ -311,7 +349,20 @@ private void lookupWithCache(String username, ActionListener listener) { protected abstract void doLookupUser(String username, ActionListener listener); - private static class CachedResult { + protected interface UserAuthenticationCache { + void invalidate(String key); + + void invalidate(String key, ListenableFuture value); + + void invalidateAll(); + + int count(); + + ListenableFuture computeIfAbsent(String key, CacheLoader> loader) + throws ExecutionException; + } + + protected static class CachedResult { private final AuthenticationResult authenticationResult; private final User user; private final char[] hash; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/ProjectScopedCache.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/ProjectScopedCache.java new file mode 100644 index 0000000000000..e15961210c4f1 --- /dev/null +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/support/ProjectScopedCache.java @@ -0,0 +1,181 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.support; + +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.CacheLoader; +import org.elasticsearch.common.cache.RemovalListener; +import org.elasticsearch.common.cache.RemovalNotification; +import org.elasticsearch.common.util.concurrent.ReleasableLock; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper; + +import java.util.Objects; +import java.util.concurrent.ExecutionException; +import java.util.function.ToLongBiFunction; + +/** + * Wrapper around a {@link Cache} instance where a composite key of the original key and the current project id, resolved through a + * {@link ProjectResolver}, is used to write to and read from the cache. + *

+ * During invalidation the cache is protected through locking in the {@link CacheIteratorHelper} because the result of iteration under any + * mutation other than Cache.CacheIterator#remove() is undefined. Concurrent writes are allowed as long as there is no active + * invalidation, the cache is protected against that by acquiring a read lock (blocking if invalidation in progress) before writing. + * + * @param key type + * @param value type + */ +public class ProjectScopedCache { + private final Cache, V> cache; + private final ProjectResolver projectResolver; + private final CacheIteratorHelper, V> cacheIteratorHelper; + + // Visible for testing + ProjectScopedCache(ProjectResolver projectResolver, Cache, V> cache) { + this.projectResolver = projectResolver; + this.cache = cache; + cacheIteratorHelper = new CacheIteratorHelper<>(cache); + } + + public void invalidateProject() { + if (projectResolver.supportsMultipleProjects()) { + invalidateProject(projectResolver.getProjectId()); + } else { + cache.invalidateAll(); + } + } + + public void invalidateProject(ProjectId projectId) { + cacheIteratorHelper.removeKeysIf(key -> key.projectId().equals(projectId)); + } + + public void invalidate(K key) { + invalidate(projectResolver.getProjectId(), key); + } + + public void invalidate(ProjectId projectId, K key) { + try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) { + cache.invalidate(new ProjectScoped<>(projectId, key)); + } + } + + public void invalidate(K key, V value) { + try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) { + cache.invalidate(new ProjectScoped<>(projectResolver.getProjectId(), key), value); + } + } + + public void invalidateAll() { + try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) { + cache.invalidateAll(); + } + } + + public V computeIfAbsent(K key, CacheLoader loader) throws ExecutionException { + try (var ignored = cacheIteratorHelper.acquireUpdateLock()) { + return cache.computeIfAbsent(new ProjectScoped<>(projectResolver.getProjectId(), key), k -> loader.load(k.value)); + } + } + + public long weight() { + return cache.weight(); + } + + public int count() { + return cache.count(); + } + + public static Builder builder() { + return new Builder<>(); + } + + public static class Builder { + private long maximumWeight; + private TimeValue expireAfterAccessNanos; + private TimeValue expireAfterWrite; + private ToLongBiFunction weigher; + private RemovalListener removalListener; + + private Builder() {} + + public Builder setMaximumWeight(long maximumWeight) { + if (maximumWeight < 0) { + throw new IllegalArgumentException("maximumWeight < 0"); + } + this.maximumWeight = maximumWeight; + return this; + } + + public Builder setExpireAfterAccess(TimeValue expireAfterAccess) { + Objects.requireNonNull(expireAfterAccess); + final long expireAfterAccessNanos = expireAfterAccess.getNanos(); + if (expireAfterAccessNanos <= 0) { + throw new IllegalArgumentException("expireAfterAccess <= 0"); + } + this.expireAfterAccessNanos = expireAfterAccess; + return this; + } + + public Builder setExpireAfterWrite(TimeValue expireAfterWrite) { + Objects.requireNonNull(expireAfterWrite); + final long expireAfterWriteNanos = expireAfterWrite.getNanos(); + if (expireAfterWriteNanos <= 0) { + throw new IllegalArgumentException("expireAfterWrite <= 0"); + } + this.expireAfterWrite = expireAfterWrite; + return this; + } + + public Builder weigher(ToLongBiFunction weigher) { + Objects.requireNonNull(weigher); + this.weigher = weigher; + return this; + } + + public Builder removalListener(RemovalListener removalListener) { + Objects.requireNonNull(removalListener); + this.removalListener = removalListener; + return this; + } + + public ProjectScopedCache build(ProjectResolver projectResolver) { + CacheBuilder, V> cacheBuilder = CacheBuilder.builder(); + + if (maximumWeight != -1) { + cacheBuilder.setMaximumWeight(maximumWeight); + } + if (expireAfterAccessNanos != null) { + cacheBuilder.setExpireAfterAccess(expireAfterAccessNanos); + } + if (expireAfterWrite != null) { + cacheBuilder.setExpireAfterWrite(expireAfterWrite); + } + if (weigher != null) { + cacheBuilder.weigher((key, value) -> weigher.applyAsLong(key.value, value)); + } + if (removalListener != null) { + cacheBuilder.removalListener((notification) -> { + removalListener.onRemoval( + new RemovalNotification<>(notification.getKey().value, notification.getValue(), notification.getRemovalReason()) + ); + }); + } + return new ProjectScopedCache<>(projectResolver, cacheBuilder.build()); + } + } + + private record ProjectScoped(ProjectId projectId, T value) { + private ProjectScoped(ProjectId projectId, T value) { + this.projectId = Objects.requireNonNull(projectId); + this.value = value; + } + } +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/support/ProjectScopedCacheTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/support/ProjectScopedCacheTests.java new file mode 100644 index 0000000000000..950be3109b7f5 --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/support/ProjectScopedCacheTests.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.support; + +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.project.ProjectResolver; +import org.elasticsearch.common.cache.CacheLoader; +import org.elasticsearch.core.CheckedRunnable; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.cluster.metadata.ProjectId.fromId; +import static org.hamcrest.Matchers.equalTo; + +public class ProjectScopedCacheTests extends ESTestCase { + private final AtomicReference activeProjectId = new AtomicReference<>(); + private ProjectScopedCache projectScopedCache; + private CacheLoader loader; + private final ProjectResolver projectResolver = new ProjectResolver() { + @Override + public void executeOnProject(ProjectId projectId, CheckedRunnable body) { + throw new UnsupportedOperationException(); + } + + @Override + public ProjectId getProjectId() { + return activeProjectId.get(); + } + + @Override + public boolean supportsMultipleProjects() { + return true; + } + }; + + @Before + public void initCache() { + activeProjectId.set(fromId("test-project")); + projectScopedCache = ProjectScopedCache.builder().setMaximumWeight(100).build(projectResolver); + loader = key -> "value-" + key; + } + + public void testComputeIfAbsent() throws ExecutionException { + activeProjectId.set(fromId("test-project")); + String value = projectScopedCache.computeIfAbsent("key1", loader); + assertThat(value, equalTo("value-key1")); + assertThat(projectScopedCache.count(), equalTo(1)); + } + + public void testInvalidateKey() throws ExecutionException { + activeProjectId.set(fromId("test-project")); + String value = projectScopedCache.computeIfAbsent("key1", loader); + assertThat(value, equalTo("value-key1")); + assertThat(projectScopedCache.count(), equalTo(1)); + + projectScopedCache.invalidate("key1"); + assertThat(projectScopedCache.count(), equalTo(0)); + } + + public void testInvalidateProject() throws ExecutionException { + activeProjectId.set(fromId("test-project")); + projectScopedCache.computeIfAbsent("key1", loader); + activeProjectId.set(fromId("other-test-project")); + projectScopedCache.computeIfAbsent("key1", loader); + + assertThat(projectScopedCache.count(), equalTo(2)); + projectScopedCache.invalidateProject(); + assertThat(projectScopedCache.count(), equalTo(1)); + + activeProjectId.set(fromId("test-project")); + projectScopedCache.invalidateProject(); + assertThat(projectScopedCache.count(), equalTo(0)); + } + + public void testInvalidateAll() throws ExecutionException { + activeProjectId.set(fromId("test-project")); + projectScopedCache.computeIfAbsent("key1", loader); + activeProjectId.set(fromId("other-test-project")); + projectScopedCache.computeIfAbsent("key1", loader); + projectScopedCache.invalidateAll(); + assertThat(projectScopedCache.count(), equalTo(0)); + } + + public void testRemovalListener() throws ExecutionException { + final AtomicReference removedKey = new AtomicReference<>(); + + projectScopedCache = ProjectScopedCache.builder() + .setMaximumWeight(100) + .removalListener((notification) -> removedKey.set(notification.getKey())) + .build(projectResolver); + + activeProjectId.set(fromId("test-project")); + projectScopedCache.computeIfAbsent("key1", loader); + projectScopedCache.invalidate("key1"); + assertThat(removedKey.get(), equalTo("key1")); + } + + public void testWeigher() throws ExecutionException { + int numberOfEntries = randomIntBetween(2, 10); + int maximumWeight = 2 * numberOfEntries; + int weight = randomIntBetween(2, 10); + AtomicLong evictions = new AtomicLong(); + + projectScopedCache = ProjectScopedCache.builder() + .setMaximumWeight(maximumWeight) + .weigher((k, v) -> weight) + .removalListener(notification -> evictions.incrementAndGet()) + .build(projectResolver); + + for (int i = 0; i < numberOfEntries; i++) { + projectScopedCache.computeIfAbsent(Integer.toString(i), loader); + } + // cache weight should be the largest multiple of weight less than maximumWeight + assertEquals(weight * (maximumWeight / weight), projectScopedCache.weight()); + + // the number of evicted entries should be the number of entries that fit in the excess weight + assertEquals((int) Math.ceil((weight - 2) * numberOfEntries / (1.0 * weight)), evictions.get()); + } +}