Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/123602.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 123602
summary: Make `CachingUsernamePasswordRealm` and `ReloadablePlugin` Project Aware
area: Infra/Plugins
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.elasticsearch.cluster.metadata.MetadataDataStreamsService;
import org.elasticsearch.cluster.metadata.MetadataIndexTemplateService;
import org.elasticsearch.cluster.metadata.MetadataUpdateSettingsService;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.metadata.SystemIndexMetadataUpgradeService;
import org.elasticsearch.cluster.metadata.TemplateUpgradeService;
Expand Down Expand Up @@ -1498,12 +1499,26 @@ private CircuitBreakerService createCircuitBreakerService(
* @return A single ReloadablePlugin that, upon reload, reloads the plugins it wraps
*/
private static ReloadablePlugin wrapPlugins(List<ReloadablePlugin> reloadablePlugins) {
return settings -> {
for (ReloadablePlugin plugin : reloadablePlugins) {
try {
plugin.reload(settings);
} catch (IOException e) {
throw new UncheckedIOException(e);
return new ReloadablePlugin() {
@Override
public void reload(Settings settings) throws Exception {
for (ReloadablePlugin plugin : reloadablePlugins) {
try {
plugin.reload(settings);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}

@Override
public void reload(ProjectId projectId, Settings settings) throws Exception {
for (ReloadablePlugin plugin : reloadablePlugins) {
try {
plugin.reload(projectId, settings);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.plugins;

import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.common.settings.Settings;

/**
Expand Down Expand Up @@ -41,4 +42,6 @@ public interface ReloadablePlugin {
* if the offending call didn't happen.
*/
void reload(Settings settings) throws Exception;

default void reload(ProjectId projectId, Settings settings) throws Exception {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

package org.elasticsearch.plugins.internal;

import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.node.MockNode;
import org.elasticsearch.plugins.Plugin;
Expand Down Expand Up @@ -57,20 +58,26 @@ public void setReloadCallback(ReloadablePlugin reloadablePlugin) {

public void invokeReloadOperation() throws Exception {
reloadablePlugin.reload(Settings.EMPTY);
reloadablePlugin.reload(randomUniqueProjectId(), Settings.EMPTY);
}
}

public static class TestReloadablePlugin extends Plugin implements ReloadablePlugin {

private boolean reloaded = false;
private int reloadCount = 0;

@Override
public void reload(Settings settings) throws Exception {
reloaded = true;
reloadCount++;
}

@Override
public void reload(ProjectId projectId, Settings settings) throws Exception {
reloadCount++;
}

public boolean isReloaded() {
return reloaded;
return reloadCount >= 2;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand All @@ -32,25 +33,62 @@

public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm {

private final Cache<String, ListenableFuture<CachedResult>> cache;
private final UserAuthenticationCache cache;
private final ThreadPool threadPool;
private final boolean authenticationEnabled;
final Hasher cacheHasher;
protected final Hasher cacheHasher;

protected CachingUsernamePasswordRealm(RealmConfig config, ThreadPool threadPool) {
this(config, threadPool, buildDefaultCache(config));
}

protected CachingUsernamePasswordRealm(RealmConfig config, ThreadPool threadPool, UserAuthenticationCache cache) {
super(config);
cacheHasher = Hasher.resolve(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_HASH_ALGO_SETTING));
this.threadPool = threadPool;
final TimeValue ttl = this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING);
this.authenticationEnabled = config.getSetting(CachingUsernamePasswordRealmSettings.AUTHC_ENABLED_SETTING);
this.cache = cache;
}

private static UserAuthenticationCache buildDefaultCache(RealmConfig config) {
final TimeValue ttl = config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING);
if (ttl.getNanos() > 0) {
cache = CacheBuilder.<String, ListenableFuture<CachedResult>>builder()
final Cache<String, ListenableFuture<CachedResult>> cache = CacheBuilder.<String, ListenableFuture<CachedResult>>builder()
.setExpireAfterWrite(ttl)
.setMaximumWeight(this.config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING))
.setMaximumWeight(config.getSetting(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING))
.build();
} else {
cache = null;
return new UserAuthenticationCache() {
@Override
public void invalidate(String key) {
cache.invalidate(key);
}

@Override
public void invalidate(String key, ListenableFuture<CachedResult> value) {
cache.invalidate(key, value);
}

@Override
public void invalidateAll() {
cache.invalidateAll();
}

@Override
public int count() {
return cache.count();
}

@Override
public ListenableFuture<CachedResult> computeIfAbsent(
String key,
CacheLoader<String, ListenableFuture<CachedResult>> loader
) throws ExecutionException {
return cache.computeIfAbsent(key, loader);
}
};
}
this.authenticationEnabled = config.getSetting(CachingUsernamePasswordRealmSettings.AUTHC_ENABLED_SETTING);

return null;
}

@Override
Expand Down Expand Up @@ -311,7 +349,20 @@ private void lookupWithCache(String username, ActionListener<User> listener) {

protected abstract void doLookupUser(String username, ActionListener<User> listener);

private static class CachedResult {
protected interface UserAuthenticationCache {
void invalidate(String key);

void invalidate(String key, ListenableFuture<CachedResult> value);

void invalidateAll();

int count();

ListenableFuture<CachedResult> computeIfAbsent(String key, CacheLoader<String, ListenableFuture<CachedResult>> loader)
throws ExecutionException;
}

protected static class CachedResult {
private final AuthenticationResult<User> authenticationResult;
private final User user;
private final char[] hash;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.security.support;

import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.cache.RemovalListener;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.util.concurrent.ReleasableLock;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper;

import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.ToLongBiFunction;

/**
* Wrapper around a {@link Cache} instance where a composite key of the original key and the current project id, resolved through a
* {@link ProjectResolver}, is used to write to and read from the cache.
* <p>
* During invalidation the cache is protected through locking in the {@link CacheIteratorHelper} because the result of iteration under any
* mutation other than Cache.CacheIterator#remove() is undefined. Concurrent writes are allowed as long as there is no active
* invalidation, the cache is protected against that by acquiring a read lock (blocking if invalidation in progress) before writing.
*
* @param <K> key type
* @param <V> value type
*/
public class ProjectScopedCache<K, V> {
private final Cache<ProjectScoped<K>, V> cache;
private final ProjectResolver projectResolver;
private final CacheIteratorHelper<ProjectScoped<K>, V> cacheIteratorHelper;

// Visible for testing
ProjectScopedCache(ProjectResolver projectResolver, Cache<ProjectScoped<K>, V> cache) {
this.projectResolver = projectResolver;
this.cache = cache;
cacheIteratorHelper = new CacheIteratorHelper<>(cache);
}

public void invalidateProject() {
if (projectResolver.supportsMultipleProjects()) {
invalidateProject(projectResolver.getProjectId());
} else {
cache.invalidateAll();
}
}

public void invalidateProject(ProjectId projectId) {
cacheIteratorHelper.removeKeysIf(key -> key.projectId().equals(projectId));
}

public void invalidate(K key) {
invalidate(projectResolver.getProjectId(), key);
}

public void invalidate(ProjectId projectId, K key) {
try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) {
cache.invalidate(new ProjectScoped<>(projectId, key));
}
}

public void invalidate(K key, V value) {
try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) {
cache.invalidate(new ProjectScoped<>(projectResolver.getProjectId(), key), value);
}
}

public void invalidateAll() {
try (ReleasableLock ignored = cacheIteratorHelper.acquireUpdateLock()) {
cache.invalidateAll();
}
}

public V computeIfAbsent(K key, CacheLoader<K, V> loader) throws ExecutionException {
try (var ignored = cacheIteratorHelper.acquireUpdateLock()) {
return cache.computeIfAbsent(new ProjectScoped<>(projectResolver.getProjectId(), key), k -> loader.load(k.value));
}
}

public long weight() {
return cache.weight();
}

public int count() {
return cache.count();
}

public static <K, V> Builder<K, V> builder() {
return new Builder<>();
}

public static class Builder<K, V> {
private long maximumWeight;
private TimeValue expireAfterAccessNanos;
private TimeValue expireAfterWrite;
private ToLongBiFunction<K, V> weigher;
private RemovalListener<K, V> removalListener;

private Builder() {}

public Builder<K, V> setMaximumWeight(long maximumWeight) {
if (maximumWeight < 0) {
throw new IllegalArgumentException("maximumWeight < 0");
}
this.maximumWeight = maximumWeight;
return this;
}

public Builder<K, V> setExpireAfterAccess(TimeValue expireAfterAccess) {
Objects.requireNonNull(expireAfterAccess);
final long expireAfterAccessNanos = expireAfterAccess.getNanos();
if (expireAfterAccessNanos <= 0) {
throw new IllegalArgumentException("expireAfterAccess <= 0");
}
this.expireAfterAccessNanos = expireAfterAccess;
return this;
}

public Builder<K, V> setExpireAfterWrite(TimeValue expireAfterWrite) {
Objects.requireNonNull(expireAfterWrite);
final long expireAfterWriteNanos = expireAfterWrite.getNanos();
if (expireAfterWriteNanos <= 0) {
throw new IllegalArgumentException("expireAfterWrite <= 0");
}
this.expireAfterWrite = expireAfterWrite;
return this;
}

public Builder<K, V> weigher(ToLongBiFunction<K, V> weigher) {
Objects.requireNonNull(weigher);
this.weigher = weigher;
return this;
}

public Builder<K, V> removalListener(RemovalListener<K, V> removalListener) {
Objects.requireNonNull(removalListener);
this.removalListener = removalListener;
return this;
}

public ProjectScopedCache<K, V> build(ProjectResolver projectResolver) {
CacheBuilder<ProjectScoped<K>, V> cacheBuilder = CacheBuilder.builder();

if (maximumWeight != -1) {
cacheBuilder.setMaximumWeight(maximumWeight);
}
if (expireAfterAccessNanos != null) {
cacheBuilder.setExpireAfterAccess(expireAfterAccessNanos);
}
if (expireAfterWrite != null) {
cacheBuilder.setExpireAfterWrite(expireAfterWrite);
}
if (weigher != null) {
cacheBuilder.weigher((key, value) -> weigher.applyAsLong(key.value, value));
}
if (removalListener != null) {
cacheBuilder.removalListener((notification) -> {
removalListener.onRemoval(
new RemovalNotification<>(notification.getKey().value, notification.getValue(), notification.getRemovalReason())
);
});
}
return new ProjectScopedCache<>(projectResolver, cacheBuilder.build());
}
}

private record ProjectScoped<T>(ProjectId projectId, T value) {
private ProjectScoped(ProjectId projectId, T value) {
this.projectId = Objects.requireNonNull(projectId);
this.value = value;
}
}
}
Loading
Loading