diff --git a/CHANGELOG.md b/CHANGELOG.md index 728be79d55..9a866a1c3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements +- Moved configuration reloading to dedicated thread to improve node stability ([#5479](https://github.com/opensearch-project/security/pull/5479)) - Makes resource settings dynamic ([#5677](https://github.com/opensearch-project/security/pull/5677)) ### Bug Fixes diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index e8e00145e7..cc199e1f51 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -1185,6 +1185,7 @@ public Collection createComponents( final XFFResolver xffResolver = new XFFResolver(threadPool); backendRegistry = new BackendRegistry(settings, adminDns, xffResolver, auditLog, threadPool, cih); backendRegistry.registerClusterSettingsChangeListener(clusterService.getClusterSettings()); + cr.subscribeOnChange(configMap -> { backendRegistry.invalidateCache(); }); tokenManager = new SecurityTokenManager(cs, threadPool, userService); final CompatConfig compatConfig = new CompatConfig(environment, transportPassiveAuthSetting); diff --git a/src/main/java/org/opensearch/security/action/configupdate/TransportConfigUpdateAction.java b/src/main/java/org/opensearch/security/action/configupdate/TransportConfigUpdateAction.java index ad1ebc08e5..afb805341e 100644 --- a/src/main/java/org/opensearch/security/action/configupdate/TransportConfigUpdateAction.java +++ b/src/main/java/org/opensearch/security/action/configupdate/TransportConfigUpdateAction.java @@ -35,22 +35,22 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.security.auth.BackendRegistry; import org.opensearch.security.configuration.ConfigurationRepository; -import org.opensearch.security.securityconf.DynamicConfigFactory; import org.opensearch.security.securityconf.impl.CType; +import org.opensearch.security.util.TransportNodesAsyncAction; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportService; -public class TransportConfigUpdateAction extends TransportNodesAction< +public class TransportConfigUpdateAction extends TransportNodesAsyncAction< ConfigUpdateRequest, ConfigUpdateResponse, TransportConfigUpdateAction.NodeConfigUpdateRequest, @@ -59,7 +59,6 @@ public class TransportConfigUpdateAction extends TransportNodesAction< protected Logger logger = LogManager.getLogger(getClass()); private final Provider backendRegistry; private final ConfigurationRepository configurationRepository; - private DynamicConfigFactory dynamicConfigFactory; private static final Set> SELECTIVE_VALIDATION_TYPES = Set.of(CType.INTERNALUSERS); // Note: While INTERNALUSERS is used as a marker, the cache invalidation // applies to all user types (internal, LDAP, etc.) @@ -72,8 +71,7 @@ public TransportConfigUpdateAction( final TransportService transportService, final ConfigurationRepository configurationRepository, final ActionFilters actionFilters, - Provider backendRegistry, - DynamicConfigFactory dynamicConfigFactory + Provider backendRegistry ) { super( ConfigUpdateAction.NAME, @@ -84,12 +82,12 @@ public TransportConfigUpdateAction( ConfigUpdateRequest::new, TransportConfigUpdateAction.NodeConfigUpdateRequest::new, ThreadPool.Names.MANAGEMENT, + ThreadPool.Names.SAME, ConfigUpdateNodeResponse.class ); this.configurationRepository = configurationRepository; this.backendRegistry = backendRegistry; - this.dynamicConfigFactory = dynamicConfigFactory; } public static class NodeConfigUpdateRequest extends TransportRequest { @@ -128,17 +126,29 @@ protected ConfigUpdateResponse newResponse( } @Override - protected ConfigUpdateNodeResponse nodeOperation(final NodeConfigUpdateRequest request) { + protected void nodeOperation(NodeConfigUpdateRequest request, ActionListener listener) { final var configupdateRequest = request.request; if (canHandleSelectively(configupdateRequest)) { backendRegistry.get().invalidateUserCache(configupdateRequest.getEntityNames()); + listener.onResponse(new ConfigUpdateNodeResponse(clusterService.localNode(), configupdateRequest.getConfigTypes(), null)); } else { - boolean didReload = configurationRepository.reloadConfiguration(CType.fromStringValues((configupdateRequest.getConfigTypes()))); - if (didReload) { - backendRegistry.get().invalidateCache(); - } + configurationRepository.reloadConfiguration( + CType.fromStringValues((configupdateRequest.getConfigTypes())), + new ActionListener<>() { + @Override + public void onResponse(ConfigurationRepository.ConfigReloadResponse configReloadResponse) { + listener.onResponse( + new ConfigUpdateNodeResponse(clusterService.localNode(), configupdateRequest.getConfigTypes(), null) + ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + ); } - return new ConfigUpdateNodeResponse(clusterService.localNode(), configupdateRequest.getConfigTypes(), null); } private boolean canHandleSelectively(ConfigUpdateRequest request) { diff --git a/src/main/java/org/opensearch/security/configuration/ConfigUpdateAlreadyInProgressException.java b/src/main/java/org/opensearch/security/configuration/ConfigUpdateAlreadyInProgressException.java deleted file mode 100644 index 6387c17103..0000000000 --- a/src/main/java/org/opensearch/security/configuration/ConfigUpdateAlreadyInProgressException.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2015-2019 floragunn GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package org.opensearch.security.configuration; - -import java.io.IOException; - -import org.opensearch.OpenSearchException; -import org.opensearch.core.common.io.stream.StreamInput; - -public class ConfigUpdateAlreadyInProgressException extends OpenSearchException { - - public ConfigUpdateAlreadyInProgressException(StreamInput in) throws IOException { - super(in); - } - - public ConfigUpdateAlreadyInProgressException(String msg, Object... args) { - super(msg, args); - } - - public ConfigUpdateAlreadyInProgressException(String msg, Throwable cause, Object... args) { - super(msg, cause, args); - } - - public ConfigUpdateAlreadyInProgressException(Throwable cause) { - super(cause); - } - -} diff --git a/src/main/java/org/opensearch/security/configuration/ConfigurationRepository.java b/src/main/java/org/opensearch/security/configuration/ConfigurationRepository.java index 737f86ba74..58042f868b 100644 --- a/src/main/java/org/opensearch/security/configuration/ConfigurationRepository.java +++ b/src/main/java/org/opensearch/security/configuration/ConfigurationRepository.java @@ -28,13 +28,10 @@ import java.io.File; import java.nio.file.Path; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.text.SimpleDateFormat; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; -import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,14 +41,14 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -71,6 +68,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Priority; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.core.action.ActionListener; @@ -92,7 +90,6 @@ import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.ConfigHelper; import org.opensearch.security.support.SecurityIndexHandler; -import org.opensearch.security.support.SecurityUtils; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.Client; @@ -124,6 +121,8 @@ public class ConfigurationRepository implements ClusterStateListener, IndexEvent private final SecurityIndexHandler securityIndexHandler; + private final ReloadThread reloadThread; + // visible for testing protected ConfigurationRepository( final String securityIndex, @@ -148,6 +147,7 @@ protected ConfigurationRepository( this.cl = configurationLoaderSecurity7; configCache = CacheBuilder.newBuilder().build(); this.securityIndexHandler = securityIndexHandler; + this.reloadThread = new ReloadThread(settings, this::doReload); } private Path resolveConfigDir() { @@ -280,7 +280,7 @@ private void initalizeClusterConfiguration(final boolean installDefaultConfig) { while (!dynamicConfigFactory.isInitialized()) { try { LOGGER.debug("Try to load config ..."); - reloadConfiguration(CType.values(), true); + doReload(CType.values()); break; } catch (Exception e) { LOGGER.debug("Unable to load configuration due to {}", String.valueOf(ExceptionUtils.getRootCause(e))); @@ -414,6 +414,7 @@ Future executeConfigurationInitialization(final SecurityMetadata securityM setupAuditConfigurationIfAny(auditConfigDocPresent); auditHotReloadingEnabled.getAndSet(auditConfigDocPresent); initalizeConfigTask.complete(null); + this.reloadThread.start(); LOGGER.info( "Security configuration initialized. Applied hashes: {}", securityMetadata.configuration() @@ -435,8 +436,13 @@ public CompletableFuture initOnNodeStart() { final Supplier> startInitialization = () -> { new Thread(() -> { - initalizeClusterConfiguration(installDefaultConfig); - initalizeConfigTask.complete(null); + try { + initalizeClusterConfiguration(installDefaultConfig); + initalizeConfigTask.complete(null); + } finally { + // After initialization is complete, start the update thread so that we execute any pending update requests + this.reloadThread.start(); + } }).start(); return initalizeConfigTask.thenApply(result -> installDefaultConfig); }; @@ -458,6 +464,7 @@ public CompletableFuture initOnNodeStart() { securityIndex ); initalizeConfigTask.complete(null); + this.reloadThread.start(); return initalizeConfigTask.thenApply(result -> installDefaultConfig); } } catch (Throwable e2) { @@ -518,40 +525,25 @@ public SecurityDynamicConfiguration getConfiguration(CType configurati return SecurityDynamicConfiguration.empty(configurationType); } - private final Lock LOCK = new ReentrantLock(); - - public boolean reloadConfiguration(final Collection> configTypes) throws ConfigUpdateAlreadyInProgressException { - return reloadConfiguration(configTypes, false); - } - - private boolean reloadConfiguration(final Collection> configTypes, final boolean fromBackgroundThread) - throws ConfigUpdateAlreadyInProgressException { - if (!fromBackgroundThread && !initalizeConfigTask.isDone()) { - LOGGER.warn("Unable to reload configuration, initalization thread has not yet completed."); - return false; - } - return loadConfigurationWithLock(configTypes); - } + /** + * Requests a reload of the currently used configuration. If a configuration update is currently in progress, + * another update will be queued. This method will not queue several updates; rather, it will combine several + * updates into one. - private boolean loadConfigurationWithLock(Collection> configTypes) { - try { - if (LOCK.tryLock(60, TimeUnit.SECONDS)) { - try { - reloadConfiguration0(configTypes, this.acceptInvalid); - return true; - } finally { - LOCK.unlock(); - } - } else { - throw new ConfigUpdateAlreadyInProgressException("A config update is already in progress"); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new ConfigUpdateAlreadyInProgressException("Interrupted config update"); - } + * @param configTypes the configuration types to be reloaded. + * @param listener an listener to be notified when the reload was finished. You can provide null if you do not want + * such a notification + */ + public void reloadConfiguration(Collection> configTypes, ActionListener listener) { + this.reloadThread.requestReload(configTypes, listener); } - private void reloadConfiguration0(Collection> configTypes, boolean acceptInvalid) { + /** + * Reloads the currently used configuration. Usually, you should not call this directly. Rather, use the reloadConfiguration() methods. + * This method should be only called directly via the update or initialization threads in order to make sure that only one + * reload is active at the same time. + */ + private void doReload(Set> configTypes) { ConfigurationMap loaded = getConfigurationsFromIndex(configTypes, false, acceptInvalid); notifyConfigurationListeners(loaded); } @@ -645,35 +637,10 @@ private ConfigurationMap validate(ConfigurationMap conf, int expectedSize) throw return conf; } - private static String formatDate(long date) { - return new SimpleDateFormat("yyyy-MM-dd", SecurityUtils.EN_Locale).format(new Date(date)); - } - public static int getDefaultConfigVersion() { return ConfigurationRepository.DEFAULT_CONFIG_VERSION; } - @SuppressWarnings("removal") - private class AccessControllerWrappedThread extends Thread { - private final Thread innerThread; - - public AccessControllerWrappedThread(Thread innerThread) { - this.innerThread = innerThread; - } - - @Override - public void run() { - AccessController.doPrivileged(new PrivilegedAction() { - - @Override - public Void run() { - innerThread.run(); - return null; - } - }); - } - } - @Override public void afterIndexShardStarted(IndexShard indexShard) { final ShardId shardId = indexShard.shardId(); @@ -686,10 +653,197 @@ public void afterIndexShardStarted(IndexShard indexShard) { threadPool.generic().execute(() -> { if (isSecurityIndexRestoredFromSnapshot(clusterService, index, securityIndex)) { LOGGER.info("Security index primary shard {} started - config reloading for snapshot restore", shardId); - reloadConfiguration(CType.values()); + reloadConfiguration(CType.values(), null); } }); } } } + + /** + * This class is responsible for managing requests to reload the security index. Its main purpose + * is to make sure that there is no unbounded queue of reload requests. Rather, it works this way: + *
    + *
  • If there is no reload activity, just schedule it immediately.
  • + *
  • If a reload is currently in process, schedule a further reload right afterwards.
  • + *
  • If a reload is currently in process and a further reload is already scheduled, just rely on the already scheduled reload. + * If there are configuration types requested to be reloaded, which are not scheduled so far, the requested configuration + * types of the scheduled reload are expanded.
  • + *
+ * Reloading will always take place on a single, dedicated thread. + *

+ * After an instance of this class has been created, the thread won't be running yet. You need to manually + * call the start() method to start the thread. This is to allow initialization code to run without having the + * thread already interfering. However, this also means that you must sure that you do not forget to call the + * start() method. Not calling it means that a cluster won't be able to get security updates. + */ + static class ReloadThread { + + private final Consumer>> performFunction; + private final Thread thread; + private final Object requestLock = new Object(); + private boolean started = false; + + /** + * This is the request queue - even though it is not actually a queue. We collect here + * the configuration types for which a reload was requested but not yet performed. + * Several consecutive requests will just extend this collection - if necessary. + */ + private ImmutableSet> reloadRequestedFor = ImmutableSet.of(); + + /** + * Action listeners to be called when the reload was finished. We collect the action listeners here until + * the reload is actually in progress. + */ + private List> reloadRequestedForActionListeners = new ArrayList<>(); + + /** + * The time we got the first currently queued reload request. + */ + private Instant reloadRequestedAt; + + /** + * This contains the configuration types for which a reload is in progress just right now. + */ + private ImmutableSet> reloadInProgressFor = ImmutableSet.of(); + + ReloadThread(Settings settings, Consumer>> performFunction) { + this.performFunction = performFunction; + this.thread = OpenSearchExecutors.daemonThreadFactory(settings, "ConfigurationRepository#ReloadThread").newThread(this::run); + } + + /** + * Requests an async configuration reload for the given configuration types. Calling this method + * will not wait for the configuration reload to complete. + */ + void requestReload(Collection> configurationTypes, ActionListener actionListener) { + synchronized (this.requestLock) { + if (!this.started) { + LOGGER.info("Cannot reload configuration yet, because the initialization process did not complete yet"); + } + + if (actionListener != null) { + this.reloadRequestedForActionListeners.add(actionListener); + } + + if (this.reloadRequestedFor.isEmpty()) { + LOGGER.debug("Configuration reload request received for {}; notifying update thread", configurationTypes); + this.reloadRequestedAt = Instant.now(); + this.reloadRequestedFor = ImmutableSet.copyOf(configurationTypes); + this.requestLock.notifyAll(); + } else if (!this.reloadRequestedFor.containsAll(configurationTypes)) { + LOGGER.debug( + "Configuration reload request received for {}; adding new configuration types to already requested {}", + configurationTypes, + this.reloadRequestedFor + ); + this.reloadRequestedFor = ImmutableSet.>builder() + .addAll(this.reloadRequestedFor) + .addAll(configurationTypes) + .build(); + } else { + if (Duration.between(this.reloadRequestedAt, Instant.now()).toMillis() > 30000) { + // Reload request is queued for more than 30 seconds; let us log a warning about that + LOGGER.warn( + "Configuration reload request received; another update request is already queued since {}", + this.reloadRequestedAt + ); + } else { + LOGGER.debug( + "Configuration reload request received; another update request is already queued since {}", + this.reloadRequestedAt + ); + } + } + } + } + + /** + * Starts the reload thread. Calling this method after the thread was already started will have no further effect. + */ + void start() { + synchronized (this.requestLock) { + if (!this.started) { + this.thread.start(); + this.started = true; + } + } + } + + /** + * Returns true if no reload is in progress and no reload has been queued. + */ + boolean isIdle() { + synchronized (this.requestLock) { + return this.reloadRequestedFor.isEmpty() && this.reloadInProgressFor.isEmpty(); + } + } + + /** + * Returns true if nothing is queued. Still, an active reload might be in progress. + */ + boolean queueIsEmpty() { + synchronized (this.requestLock) { + return this.reloadRequestedFor.isEmpty(); + } + } + + private void run() { + for (;;) { + ImmutableSet> localReloadRequestedFor; + List> localReloadRequestedForActionListeners = null; + try { + + synchronized (this.requestLock) { + this.reloadInProgressFor = ImmutableSet.of(); + + while (this.reloadRequestedFor.isEmpty()) { + this.requestLock.wait(); + } + + // We save here the requested configuration types in order to pass them to the updateFunction later + localReloadRequestedFor = this.reloadRequestedFor; + localReloadRequestedForActionListeners = new ArrayList<>(this.reloadRequestedForActionListeners); + + LOGGER.info( + "Performing configuration reload for request at {} on {}", + this.reloadRequestedAt, + localReloadRequestedFor + ); + + // Already set updateRequestedAt to null now. Thus, any further updates that come in during the + // following update process will be already recognized again and queued. + this.reloadRequestedAt = null; + this.reloadRequestedFor = ImmutableSet.of(); + this.reloadRequestedForActionListeners.clear(); + this.reloadInProgressFor = localReloadRequestedFor; + } + + this.performFunction.accept(localReloadRequestedFor); + for (ActionListener listener : localReloadRequestedForActionListeners) { + listener.onResponse(new ConfigReloadResponse(localReloadRequestedFor)); + } + } catch (Exception e) { + LOGGER.error("Error in {}", this.thread.getName(), e); + if (localReloadRequestedForActionListeners != null) { + for (ActionListener listener : localReloadRequestedForActionListeners) { + listener.onFailure(e); + } + } + } + } + } + } + + public static class ConfigReloadResponse { + private final Set> reloadedConfigTypes; + + ConfigReloadResponse(Set> reloadedConfigTypes) { + this.reloadedConfigTypes = reloadedConfigTypes; + } + + public Set> getReloadedConfigTypes() { + return reloadedConfigTypes; + } + } } diff --git a/src/main/java/org/opensearch/security/util/TransportNodesAsyncAction.java b/src/main/java/org/opensearch/security/util/TransportNodesAsyncAction.java new file mode 100644 index 0000000000..644bb09074 --- /dev/null +++ b/src/main/java/org/opensearch/security/util/TransportNodesAsyncAction.java @@ -0,0 +1,307 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.util; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReferenceArray; + +import org.apache.logging.log4j.message.ParameterizedMessage; + +import org.opensearch.action.ActionRunnable; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.ChannelActionListener; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.NodeShouldNotConnectException; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +/** + * A variation of org.opensearch.action.support.nodes.TransportNodesAction where the node operation is executed asynchronously + * + * Based on https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/support/nodes/TransportNodesAction.java + * + * @opensearch.internal + */ +public abstract class TransportNodesAsyncAction< + NodesRequest extends BaseNodesRequest, + NodesResponse extends BaseNodesResponse, + NodeRequest extends TransportRequest, + NodeResponse extends BaseNodeResponse> extends HandledTransportAction { + + protected final ThreadPool threadPool; + protected final ClusterService clusterService; + protected final TransportService transportService; + protected final Class nodeResponseClass; + protected final String transportNodeAction; + + private final String finalExecutor; + + /** + * @param actionName action name + * @param threadPool thread-pool + * @param clusterService cluster service + * @param transportService transport service + * @param actionFilters action filters + * @param request node request writer + * @param nodeRequest node request reader + * @param nodeExecutor executor to execute node action on + * @param finalExecutor executor to execute final collection of all responses on + * @param nodeResponseClass class of the node responses + */ + protected TransportNodesAsyncAction( + String actionName, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Writeable.Reader request, + Writeable.Reader nodeRequest, + String nodeExecutor, + String finalExecutor, + Class nodeResponseClass + ) { + super(actionName, transportService, actionFilters, request); + this.threadPool = threadPool; + this.clusterService = Objects.requireNonNull(clusterService); + this.transportService = Objects.requireNonNull(transportService); + this.nodeResponseClass = Objects.requireNonNull(nodeResponseClass); + + this.transportNodeAction = actionName + "[n]"; + this.finalExecutor = finalExecutor; + transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler()); + } + + @Override + protected void doExecute(Task task, NodesRequest request, ActionListener listener) { + new AsyncAction(task, request, listener).start(); + } + + /** + * Map the responses into {@code nodeResponseClass} responses and {@link FailedNodeException}s. + * + * @param request The associated request. + * @param nodesResponses All node-level responses + * @return Never {@code null}. + * @throws NullPointerException if {@code nodesResponses} is {@code null} + * @see #newResponse(BaseNodesRequest, List, List) + */ + protected NodesResponse newResponse(NodesRequest request, AtomicReferenceArray nodesResponses) { + final List responses = new ArrayList<>(); + final List failures = new ArrayList<>(); + + for (int i = 0; i < nodesResponses.length(); ++i) { + Object response = nodesResponses.get(i); + + if (response instanceof FailedNodeException) { + failures.add((FailedNodeException) response); + } else { + responses.add(nodeResponseClass.cast(response)); + } + } + + return newResponse(request, responses, failures); + } + + /** + * Create a new {@link NodesResponse} (multi-node response). + * + * @param request The associated request. + * @param responses All successful node-level responses. + * @param failures All node-level failures. + * @return Never {@code null}. + * @throws NullPointerException if any parameter is {@code null}. + */ + protected abstract NodesResponse newResponse(NodesRequest request, List responses, List failures); + + protected abstract NodeRequest newNodeRequest(NodesRequest request); + + protected abstract NodeResponse newNodeResponse(StreamInput in) throws IOException; + + protected abstract void nodeOperation(NodeRequest request, ActionListener listener); + + /** + * resolve node ids to concrete nodes of the incoming request + **/ + protected void resolveRequest(NodesRequest request, ClusterState clusterState) { + assert request.concreteNodes() == null : "request concreteNodes shouldn't be set"; + String[] nodesIds = clusterState.nodes().resolveNodes(request.nodesIds()); + request.setConcreteNodes(Arrays.stream(nodesIds).map(clusterState.nodes()::get).toArray(DiscoveryNode[]::new)); + } + + /** + * Get a backwards compatible transport action name + */ + protected String getTransportNodeAction(DiscoveryNode node) { + return transportNodeAction; + } + + /** + * Asynchronous action + * + * @opensearch.internal + */ + class AsyncAction { + + private final NodesRequest request; + private final ActionListener listener; + private final AtomicReferenceArray responses; + private final DiscoveryNode[] concreteNodes; + private final AtomicInteger counter = new AtomicInteger(); + private final Task task; + + AsyncAction(Task task, NodesRequest request, ActionListener listener) { + this.task = task; + this.request = request; + this.listener = listener; + if (request.concreteNodes() == null) { + resolveRequest(request, clusterService.state()); + assert request.concreteNodes() != null; + } + this.responses = new AtomicReferenceArray<>(request.concreteNodes().length); + this.concreteNodes = request.concreteNodes(); + // As we transfer the ownership of discovery nodes to route the request to into the AsyncAction class, + // we remove the list of DiscoveryNodes from the request. This reduces the payload of the request and improves + // the number of concrete nodes in the memory. + request.setConcreteNodes(null); + } + + void start() { + if (this.concreteNodes.length == 0) { + // nothing to notify + threadPool.generic().execute(() -> listener.onResponse(newResponse(request, responses))); + return; + } + TransportRequestOptions.Builder builder = TransportRequestOptions.builder(); + if (request.timeout() != null) { + builder.withTimeout(request.timeout()); + } + for (int i = 0; i < this.concreteNodes.length; i++) { + final int idx = i; + final DiscoveryNode node = this.concreteNodes[i]; + final String nodeId = node.getId(); + try { + TransportRequest nodeRequest = newNodeRequest(request); + if (task != null) { + nodeRequest.setParentTask(clusterService.localNode().getId(), task.getId()); + } + transportService.sendRequest( + node, + getTransportNodeAction(node), + nodeRequest, + builder.build(), + new TransportResponseHandler() { + @Override + public NodeResponse read(StreamInput in) throws IOException { + return newNodeResponse(in); + } + + @Override + public void handleResponse(NodeResponse response) { + onOperation(idx, response); + } + + @Override + public void handleException(TransportException exp) { + onFailure(idx, node.getId(), exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + } + ); + } catch (Exception e) { + onFailure(idx, nodeId, e); + } + } + } + + private void onOperation(int idx, NodeResponse nodeResponse) { + responses.set(idx, nodeResponse); + if (counter.incrementAndGet() == responses.length()) { + finishHim(); + } + } + + private void onFailure(int idx, String nodeId, Throwable t) { + if (logger.isDebugEnabled() && !(t instanceof NodeShouldNotConnectException)) { + logger.debug(new ParameterizedMessage("failed to execute on node [{}]", nodeId), t); + } + responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)); + if (counter.incrementAndGet() == responses.length()) { + finishHim(); + } + } + + private void finishHim() { + threadPool.executor(finalExecutor).execute(ActionRunnable.supply(listener, () -> newResponse(request, responses))); + } + } + + /** + * A node transport handler + * + * @opensearch.internal + */ + class NodeTransportHandler implements TransportRequestHandler { + + @Override + public void messageReceived(NodeRequest request, TransportChannel channel, Task task) throws Exception { + nodeOperation(request, new ChannelActionListener<>(channel, actionName, request)); + } + } + +} diff --git a/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryReloadThreadTest.java b/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryReloadThreadTest.java new file mode 100644 index 0000000000..c62cd07e7d --- /dev/null +++ b/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryReloadThreadTest.java @@ -0,0 +1,186 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.configuration; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; + +import org.opensearch.common.settings.Settings; +import org.opensearch.node.Node; +import org.opensearch.security.securityconf.impl.CType; + +import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertEquals; + +public class ConfigurationRepositoryReloadThreadTest { + + static final Settings settings = Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test_node").build(); + + @Test + public void singleRequest() { + Set> requestedConfigTypes = Set.of(CType.INTERNALUSERS, CType.ROLES); + AtomicInteger reloadCounter = new AtomicInteger(0); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + }); + subject.start(); + subject.requestReload(requestedConfigTypes, null); + + await().until(subject::isIdle); + assertEquals("Exactly one reload should have been performed after the reload request", 1, reloadCounter.get()); + assertEquals("The reloaded config types match the requested config types", requestedConfigTypes, reloadedConfigTypes); + } + + @Test + public void twoRequestsBeforeStart() { + AtomicInteger reloadCounter = new AtomicInteger(0); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + }); + subject.requestReload(Set.of(CType.INTERNALUSERS), null); + subject.requestReload(Set.of(CType.ROLES), null); + subject.start(); + + await().until(subject::isIdle); + assertEquals("Exactly one reload should have been performed after the reload request", 1, reloadCounter.get()); + assertEquals( + "The reloaded config types match the requested config types", + Set.of(CType.INTERNALUSERS, CType.ROLES), + reloadedConfigTypes + ); + } + + @Test + public void oneQueuedRequest() { + AtomicInteger reloadCounter = new AtomicInteger(0); + // The following boolean allows us to synchronize between the reload code and the assertion for testing purposes. This helps to + // avoid using Thread.sleep() calls. + AtomicBoolean reloadContinueCondition = new AtomicBoolean(false); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + await().until(reloadContinueCondition::get); + }); + subject.start(); + subject.requestReload(Set.of(CType.INTERNALUSERS), null); + await().until(subject::queueIsEmpty); + + subject.requestReload(Set.of(CType.ROLES), null); + + // Signal the reload function to finish + reloadContinueCondition.set(true); + + await().until(subject::isIdle); + assertEquals("Two reload requests have been performed now", 2, reloadCounter.get()); + assertEquals( + "The reloaded config types match the requested config types", + Set.of(CType.INTERNALUSERS, CType.ROLES), + reloadedConfigTypes + ); + } + + @Test + public void twoQueuedRequests() { + AtomicInteger reloadCounter = new AtomicInteger(0); + // The following boolean allows us to synchronize between the reload code and the assertion for testing purposes. This helps to + // avoid using Thread.sleep() calls. + AtomicBoolean reloadContinueCondition = new AtomicBoolean(false); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + await().until(reloadContinueCondition::get); + }); + subject.start(); + subject.requestReload(Set.of(CType.INTERNALUSERS), null); + await().until(subject::queueIsEmpty); + + subject.requestReload(Set.of(CType.ROLES), null); + subject.requestReload(Set.of(CType.ROLESMAPPING), null); + + // Signal the reload function to finish + reloadContinueCondition.set(true); + + await().until(subject::isIdle); + assertEquals("Two reload requests have been performed now", 2, reloadCounter.get()); + assertEquals( + "The reloaded config types match the requested config types", + Set.of(CType.INTERNALUSERS, CType.ROLES, CType.ROLESMAPPING), + reloadedConfigTypes + ); + } + + @Test + public void twoQueuedRequestsWithoutTypeChange() { + AtomicInteger reloadCounter = new AtomicInteger(0); + // The following boolean allows us to synchronize between the reload code and the assertion for testing purposes. This helps to + // avoid using Thread.sleep() calls. + AtomicBoolean reloadContinueCondition = new AtomicBoolean(false); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + await().until(reloadContinueCondition::get); + }); + subject.start(); + subject.requestReload(Set.of(CType.INTERNALUSERS), null); + await().until(subject::queueIsEmpty); + + subject.requestReload(Set.of(CType.ROLES, CType.ROLESMAPPING), null); + subject.requestReload(Set.of(CType.ROLESMAPPING), null); + + // Signal the reload function to finish + reloadContinueCondition.set(true); + + await().until(subject::isIdle); + assertEquals("Two reload requests have been performed now", 2, reloadCounter.get()); + assertEquals( + "The reloaded config types match the requested config types", + Set.of(CType.INTERNALUSERS, CType.ROLES, CType.ROLESMAPPING), + reloadedConfigTypes + ); + } + + @Test + public void threadContinuesDespiteException() { + AtomicInteger reloadCounter = new AtomicInteger(0); + Set> reloadedConfigTypes = Collections.synchronizedSet(new HashSet<>()); + ConfigurationRepository.ReloadThread subject = new ConfigurationRepository.ReloadThread(settings, (configTypes) -> { + reloadCounter.incrementAndGet(); + reloadedConfigTypes.addAll(configTypes); + if (configTypes.contains(CType.AUDIT)) { + // We use the config type AUDIT to request an exception for testing + throw new RuntimeException("Throwing exception, as requested"); + } + }); + subject.start(); + subject.requestReload(Set.of(CType.AUDIT), null); + await().until(subject::queueIsEmpty); + + subject.requestReload(Set.of(CType.ROLES), null); + + await().until(subject::isIdle); + assertEquals("Two reload requests have been performed now", 2, reloadCounter.get()); + assertEquals("The reloaded config types match the requested config types", Set.of(CType.AUDIT, CType.ROLES), reloadedConfigTypes); + } + +} diff --git a/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryTest.java b/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryTest.java index 555c495955..c3ea7c80ed 100644 --- a/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryTest.java +++ b/src/test/java/org/opensearch/security/configuration/ConfigurationRepositoryTest.java @@ -15,9 +15,7 @@ import java.io.IOException; import java.nio.file.Path; import java.time.Instant; -import java.util.Collections; import java.util.Set; -import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeoutException; import com.fasterxml.jackson.databind.InjectableValues; @@ -36,22 +34,17 @@ import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.ClusterStateUpdateTask; -import org.opensearch.cluster.RestoreInProgress; import org.opensearch.cluster.block.ClusterBlocks; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Priority; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.Index; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; -import org.opensearch.index.shard.IndexShard; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.securityconf.DynamicConfigFactory; @@ -70,7 +63,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; import org.mockito.stubbing.OngoingStubbing; @@ -90,11 +82,9 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doCallRealMethod; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -593,57 +583,6 @@ public void getConfigurationsFromIndex_SecurityIndexNotInitiallyReady() throws I assertThat(result.size(), is(CType.values().size())); } - @Test - public void afterIndexShardStarted_whenSecurityIndexUpdated() throws InterruptedException, TimeoutException { - Settings settings = Settings.builder().build(); - IndexShard indexShard = mock(IndexShard.class); - ShardRouting shardRouting = mock(ShardRouting.class); - ShardId shardId = mock(ShardId.class); - Index index = mock(Index.class); - ClusterState mockClusterState = mock(ClusterState.class); - RestoreInProgress mockRestore = mock(RestoreInProgress.class); - RestoreInProgress.Entry mockEntry = mock(RestoreInProgress.Entry.class); - ExecutorService executorService = mock(ExecutorService.class); - ThreadPool threadPool = mock(ThreadPool.class); - ConfigurationRepository configurationRepository = spy(createConfigurationRepository(settings, threadPool)); - - // Setup mock behavior - when(indexShard.shardId()).thenReturn(shardId); - when(shardId.getIndex()).thenReturn(index); - when(index.getName()).thenReturn(ConfigConstants.OPENDISTRO_SECURITY_DEFAULT_CONFIG_INDEX); - when(indexShard.routingEntry()).thenReturn(shardRouting); - when(clusterService.state()).thenReturn(mockClusterState); - when(mockClusterState.custom(RestoreInProgress.TYPE)).thenReturn(mockRestore); - when(threadPool.generic()).thenReturn(executorService); - - // when replica shard updated - when(shardRouting.primary()).thenReturn(false); - configurationRepository.afterIndexShardStarted(indexShard); - verify(executorService, never()).execute(any()); - verify(configurationRepository, never()).reloadConfiguration(any()); - - // when primary shard updated - doReturn(true).when(configurationRepository).reloadConfiguration(any()); - when(shardRouting.primary()).thenReturn(true); - when(mockRestore.iterator()).thenReturn(Collections.singletonList(mockEntry).iterator()); - when(mockEntry.indices()).thenReturn(Collections.singletonList(ConfigConstants.OPENDISTRO_SECURITY_DEFAULT_CONFIG_INDEX)); - ArgumentCaptor successRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); - configurationRepository.afterIndexShardStarted(indexShard); - verify(executorService).execute(successRunnableCaptor.capture()); - successRunnableCaptor.getValue().run(); - verify(configurationRepository).reloadConfiguration(CType.values()); - - // When there is error in checking if restored from snapshot - Mockito.reset(configurationRepository, executorService); - ArgumentCaptor errorRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); - when(clusterService.state()).thenThrow(new RuntimeException("ClusterState exception")); - when(shardRouting.primary()).thenReturn(true); - configurationRepository.afterIndexShardStarted(indexShard); - verify(executorService).execute(errorRunnableCaptor.capture()); - errorRunnableCaptor.getValue().run(); - verify(configurationRepository, never()).reloadConfiguration(any()); - } - void assertClusterState(final ArgumentCaptor clusterStateUpdateTaskCaptor) throws Exception { final var initializedStateUpdate = clusterStateUpdateTaskCaptor.getValue(); assertThat(initializedStateUpdate.priority(), is(Priority.IMMEDIATE));