diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2ada84553af17..1b3f836401c1f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -248,7 +248,9 @@ public List getRestHandlers( @Override public Collection createComponents(PluginServices services) { var components = new ArrayList<>(); - var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService()); + var throttlerManager = new ThrottlerManager(settings, services.threadPool()); + throttlerManager.init(services.clusterService()); + var truncator = new Truncator(settings, services.clusterService()); serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator)); threadPoolSetOnce.set(services.threadPool()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/Throttler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/Throttler.java index 0cf0e65eaba37..2eb0b1b207073 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/Throttler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/Throttler.java @@ -7,152 +7,222 @@ package org.elasticsearch.xpack.inference.logging; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogBuilder; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import java.io.Closeable; import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; +import java.util.concurrent.locks.ReentrantReadWriteLock; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; /** - * A class that throttles calls to a logger. If a log call is made during the throttle period a counter is incremented. - * If a log call occurs after the throttle period, then the call will proceed, and it will include a message like - * "repeated X times" to indicate how often the message was attempting to be logged. + * A class that throttles calls to a logger. The first unique log message is permitted to emit a message. Any subsequent log messages + * matching a message that has already been emitted will only increment a counter. A thread runs on an interval + * to emit any log messages that have been repeated beyond the initial emitted message. Once the thread emits a repeated + * message the counter is reset. If another message is received matching a previously emitted message by the thread, it will be consider + * the first time a unique message is received and will be logged. */ public class Throttler implements Closeable { private static final Logger classLogger = LogManager.getLogger(Throttler.class); - private final TimeValue resetInterval; - private Duration durationToWait; + private final TimeValue loggingInterval; private final Clock clock; private final ConcurrentMap logExecutors; private final AtomicReference cancellableTask = new AtomicReference<>(); private final AtomicBoolean isRunning = new AtomicBoolean(true); + private final ThreadPool threadPool; + // This lock governs the ability of the utility thread to get exclusive access to remove entries + // from the map + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); /** - * Constructs the throttler and kicks of a scheduled tasks to clear the internal stats. - * - * @param resetInterval the frequency for clearing the internal stats. This protects against an ever growing - * cache - * @param durationToWait the amount of time to wait before logging a message after the threshold - * is reached + * @param loggingInterval the frequency to run a task to emit repeated log messages * @param threadPool a thread pool for running a scheduled task to clear the internal stats */ - public Throttler(TimeValue resetInterval, TimeValue durationToWait, ThreadPool threadPool) { - this(resetInterval, durationToWait, Clock.systemUTC(), threadPool, new ConcurrentHashMap<>()); + public Throttler(TimeValue loggingInterval, ThreadPool threadPool) { + this(loggingInterval, Clock.systemUTC(), threadPool, new ConcurrentHashMap<>()); + } + + /** + * @param oldThrottler a previous throttler that is being replaced + * @param loggingInterval the frequency to run a task to emit repeated log messages + */ + public Throttler(Throttler oldThrottler, TimeValue loggingInterval) { + this(loggingInterval, oldThrottler.clock, oldThrottler.threadPool, new ConcurrentHashMap<>(oldThrottler.logExecutors)); } /** * This should only be used directly for testing. */ - Throttler( - TimeValue resetInterval, - TimeValue durationToWait, - Clock clock, - ThreadPool threadPool, - ConcurrentMap logExecutors - ) { - Objects.requireNonNull(durationToWait); - Objects.requireNonNull(threadPool); - - this.resetInterval = Objects.requireNonNull(resetInterval); - this.durationToWait = Duration.ofMillis(durationToWait.millis()); + Throttler(TimeValue loggingInterval, Clock clock, ThreadPool threadPool, ConcurrentMap logExecutors) { + this.threadPool = Objects.requireNonNull(threadPool); + this.loggingInterval = Objects.requireNonNull(loggingInterval); this.clock = Objects.requireNonNull(clock); this.logExecutors = Objects.requireNonNull(logExecutors); + } - this.cancellableTask.set(startResetTask(threadPool)); + public void init() { + cancellableTask.set(startRepeatingLogEmitter()); } - private Scheduler.Cancellable startResetTask(ThreadPool threadPool) { - classLogger.debug(() -> format("Reset task scheduled with interval [%s]", resetInterval)); + private Scheduler.Cancellable startRepeatingLogEmitter() { + classLogger.debug(() -> Strings.format("Scheduling repeating log emitter with interval [%s]", loggingInterval)); + + return threadPool.scheduleWithFixedDelay(this::emitRepeatedLogs, loggingInterval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + private void emitRepeatedLogs() { + if (isRunning.get() == false) { + return; + } + + final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock(); - return threadPool.scheduleWithFixedDelay(logExecutors::clear, resetInterval, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + writeLock.lock(); + try { + for (var iter = logExecutors.values().iterator(); iter.hasNext();) { + var executor = iter.next(); + executor.logRepeatedMessages(); + iter.remove(); + } + } finally { + writeLock.unlock(); + } } - public void setDurationToWait(TimeValue durationToWait) { - this.durationToWait = Duration.ofMillis(durationToWait.millis()); + public void execute(Logger logger, Level level, String message, Throwable t) { + executeInternal(logger, level, message, t); } - public void execute(String message, Consumer consumer) { + public void execute(Logger logger, Level level, String message) { + executeInternal(logger, level, message, null); + } + + private void executeInternal(Logger logger, Level level, String message, Throwable throwable) { if (isRunning.get() == false) { return; } - LogExecutor logExecutor = logExecutors.compute(message, (key, value) -> { - if (value == null) { - return new LogExecutor(clock, consumer); - } + final ReentrantReadWriteLock.ReadLock readLock = lock.readLock(); - return value.compute(consumer, durationToWait); - }); + readLock.lock(); + try { + var logExecutor = logExecutors.compute( + message, + (key, value) -> Objects.requireNonNullElseGet(value, () -> new LogExecutor(clock, logger, level, message, throwable)) + ); - // This executes an internal consumer that wraps the passed in one, it will either log the message passed here - // unchanged, do nothing if it is in the throttled period, or log this message + some text saying how many times it was repeated - logExecutor.log(message); + logExecutor.logFirstMessage(); + } finally { + readLock.unlock(); + } } @Override public void close() { isRunning.set(false); - cancellableTask.get().cancel(); - logExecutors.clear(); + if (cancellableTask.get() != null) { + cancellableTask.get().cancel(); + } + + clearLogExecutors(); + } + + private void clearLogExecutors() { + final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock(); + writeLock.lock(); + try { + logExecutors.clear(); + } finally { + writeLock.unlock(); + } } private static class LogExecutor { - private final long skippedLogCalls; - private final Instant timeOfLastLogCall; + // -1 here because we need to determine if we haven't logged the first time + // After the first time we'll set it to 0, then the thread that runs on an interval + // needs to know if there are any repeated message, if it sees 0, it knows there are none + // and skips emitting the message again. + private static final long INITIAL_LOG_COUNTER_VALUE = -1; + + private final AtomicLong skippedLogCalls = new AtomicLong(INITIAL_LOG_COUNTER_VALUE); + private final AtomicReference timeOfLastLogCall; private final Clock clock; - private final Consumer consumer; + private final Logger throttledLogger; + private final Level level; + private final String originalMessage; + private final Throwable throwable; - LogExecutor(Clock clock, Consumer throttledConsumer) { - this(clock, 0, throttledConsumer); - } - - LogExecutor(Clock clock, long skippedLogCalls, Consumer consumer) { - this.skippedLogCalls = skippedLogCalls; + LogExecutor(Clock clock, Logger logger, Level level, String originalMessage, @Nullable Throwable throwable) { this.clock = Objects.requireNonNull(clock); - timeOfLastLogCall = Instant.now(this.clock); - this.consumer = Objects.requireNonNull(consumer); + timeOfLastLogCall = new AtomicReference<>(Instant.now(this.clock)); + this.throttledLogger = Objects.requireNonNull(logger); + this.level = Objects.requireNonNull(level); + this.originalMessage = Objects.requireNonNull(originalMessage); + this.throwable = throwable; } - void log(String message) { - this.consumer.accept(message); + void logRepeatedMessages() { + var numSkippedLogCalls = skippedLogCalls.get(); + if (hasRepeatedLogsToEmit(numSkippedLogCalls) == false) { + return; + } + + String enrichedMessage; + if (numSkippedLogCalls == 1) { + enrichedMessage = Strings.format("%s, repeated 1 time, last message at [%s]", originalMessage, timeOfLastLogCall.get()); + } else { + enrichedMessage = Strings.format( + "%s, repeated %s times, last message at [%s]", + originalMessage, + skippedLogCalls, + timeOfLastLogCall.get() + ); + } + + log(enrichedMessage); } - LogExecutor compute(Consumer executor, Duration durationToWait) { - if (hasDurationExpired(durationToWait)) { - String messageToAppend = ""; - if (this.skippedLogCalls == 1) { - messageToAppend = ", repeated 1 time"; - } else if (this.skippedLogCalls > 1) { - messageToAppend = format(", repeated %s times", this.skippedLogCalls); - } - - final String stringToAppend = messageToAppend; - return new LogExecutor(this.clock, 0, (message) -> executor.accept(message.concat(stringToAppend))); + private void log(String enrichedMessage) { + LogBuilder builder = throttledLogger.atLevel(level); + if (throwable != null) { + builder = builder.withThrowable(throwable); } - // This creates a consumer that won't do anything because the original consumer is being throttled - return new LogExecutor(this.clock, this.skippedLogCalls + 1, (message) -> {}); + builder.log(enrichedMessage); + } + + private static boolean hasRepeatedLogsToEmit(long numSkippedLogCalls) { + return numSkippedLogCalls > 0; + } + + void logFirstMessage() { + timeOfLastLogCall.set(Instant.now(this.clock)); + + if (hasLoggedOriginalMessage(skippedLogCalls.getAndIncrement()) == false) { + log(originalMessage); + } } - private boolean hasDurationExpired(Duration durationToWait) { - Instant now = Instant.now(clock); - return now.isAfter(timeOfLastLogCall.plus(durationToWait)); + private static boolean hasLoggedOriginalMessage(long numSkippedLogCalls) { + // a negative value indicates that we haven't yet logged the original message + return numSkippedLogCalls >= 0; } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/ThrottlerManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/ThrottlerManager.java index d333cc92d61de..0e9fe3996d45d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/ThrottlerManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/logging/ThrottlerManager.java @@ -7,11 +7,13 @@ package org.elasticsearch.xpack.inference.logging; +import org.apache.logging.log4j.Level; import org.apache.logging.log4j.Logger; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.threadpool.ThreadPool; import java.io.Closeable; @@ -24,8 +26,11 @@ public class ThrottlerManager implements Closeable { private static final TimeValue DEFAULT_STATS_RESET_INTERVAL_TIME = TimeValue.timeValueDays(1); /** - * A setting specifying the interval for clearing the cached log message stats + * Legacy log throttling setting, kept for BWC compatibility. This setting has no effect in 9.1.0 and later. Do not use. + * TODO remove in 10.0 */ + @UpdateForV10(owner = UpdateForV10.Owner.MACHINE_LEARNING) + @Deprecated public static final Setting STATS_RESET_INTERVAL_SETTING = Setting.timeSetting( "xpack.inference.logging.reset_interval", DEFAULT_STATS_RESET_INTERVAL_TIME, @@ -35,8 +40,11 @@ public class ThrottlerManager implements Closeable { private static final TimeValue DEFAULT_WAIT_DURATION_TIME = TimeValue.timeValueHours(1); /** - * A setting specifying the amount of time to wait after a log call occurs before allowing another log call. + * Legacy log throttling setting, kept for BWC compatibility. This setting has no effect in 9.1.0 and later. Do not use. + * TODO remove in 10.0 */ + @UpdateForV10(owner = UpdateForV10.Owner.MACHINE_LEARNING) + @Deprecated public static final Setting LOGGER_WAIT_DURATION_SETTING = Setting.timeSetting( "xpack.inference.logging.wait_duration", DEFAULT_WAIT_DURATION_TIME, @@ -44,39 +52,42 @@ public class ThrottlerManager implements Closeable { Setting.Property.Dynamic ); - private final ThreadPool threadPool; - private Throttler throttler; - private LoggerSettings loggerSettings; + private static final TimeValue DEFAULT_LOG_EMIT_INTERVAL = TimeValue.timeValueHours(1); + /** + * This setting specifies how often a thread will run to emit repeated log messages. + */ + public static final Setting LOG_EMIT_INTERVAL = Setting.timeSetting( + "xpack.inference.logging.interval", + DEFAULT_LOG_EMIT_INTERVAL, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); - public ThrottlerManager(Settings settings, ThreadPool threadPool, ClusterService clusterService) { - Objects.requireNonNull(settings); - Objects.requireNonNull(clusterService); + private volatile TimeValue logInterval; - this.threadPool = Objects.requireNonNull(threadPool); - this.loggerSettings = LoggerSettings.fromSettings(settings); + private Throttler throttler; - throttler = new Throttler(loggerSettings.resetInterval(), loggerSettings.waitDuration(), threadPool); - this.addSettingsUpdateConsumers(clusterService); - } + public ThrottlerManager(Settings settings, ThreadPool threadPool) { + Objects.requireNonNull(settings); - private void addSettingsUpdateConsumers(ClusterService clusterService) { - clusterService.getClusterSettings().addSettingsUpdateConsumer(STATS_RESET_INTERVAL_SETTING, this::setResetInterval); - clusterService.getClusterSettings().addSettingsUpdateConsumer(LOGGER_WAIT_DURATION_SETTING, this::setWaitDuration); + throttler = new Throttler(LOG_EMIT_INTERVAL.get(settings), threadPool); + throttler.init(); } - // default for testing - void setWaitDuration(TimeValue waitDuration) { - loggerSettings = loggerSettings.createWithWaitDuration(waitDuration); + public void init(ClusterService clusterService) { + Objects.requireNonNull(clusterService); - throttler.setDurationToWait(waitDuration); + clusterService.getClusterSettings().addSettingsUpdateConsumer(LOG_EMIT_INTERVAL, this::setLogInterval); } // default for testing - void setResetInterval(TimeValue resetInterval) { - loggerSettings = loggerSettings.createWithResetInterval(resetInterval); + void setLogInterval(TimeValue logInterval) { + this.logInterval = logInterval; - throttler.close(); - throttler = new Throttler(loggerSettings.resetInterval(), loggerSettings.waitDuration(), threadPool); + var oldThrottler = throttler; + throttler = new Throttler(oldThrottler, this.logInterval); + throttler.init(); + oldThrottler.close(); } // default for testing @@ -88,13 +99,13 @@ public void warn(Logger logger, String message, Throwable e) { Objects.requireNonNull(message); Objects.requireNonNull(e); - throttler.execute(message, messageToLog -> logger.warn(messageToLog, e)); + throttler.execute(logger, Level.WARN, message, e); } public void warn(Logger logger, String message) { Objects.requireNonNull(message); - throttler.execute(message, logger::warn); + throttler.execute(logger, Level.WARN, message); } @Override @@ -103,25 +114,6 @@ public void close() { } public static List> getSettingsDefinitions() { - return List.of(STATS_RESET_INTERVAL_SETTING, LOGGER_WAIT_DURATION_SETTING); - } - - private record LoggerSettings(TimeValue resetInterval, TimeValue waitDuration) { - LoggerSettings { - Objects.requireNonNull(resetInterval); - Objects.requireNonNull(waitDuration); - } - - static LoggerSettings fromSettings(Settings settings) { - return new LoggerSettings(STATS_RESET_INTERVAL_SETTING.get(settings), LOGGER_WAIT_DURATION_SETTING.get(settings)); - } - - LoggerSettings createWithResetInterval(TimeValue resetInterval) { - return new LoggerSettings(resetInterval, waitDuration); - } - - LoggerSettings createWithWaitDuration(TimeValue waitDuration) { - return new LoggerSettings(resetInterval, waitDuration); - } + return List.of(LOG_EMIT_INTERVAL, STATS_RESET_INTERVAL_SETTING, LOGGER_WAIT_DURATION_SETTING); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java index e7160f0390669..41d50da0efe4c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerManagerTests.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.logging; -import org.apache.logging.log4j.Logger; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -18,22 +18,23 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.logging.ThrottlerTests.mockLogger; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class ThrottlerManagerTests extends ESTestCase { - private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); - private ThreadPool threadPool; + private DeterministicTaskQueue taskQueue; @Before public void init() { threadPool = createThreadPool(inferenceUtilityPool()); + taskQueue = new DeterministicTaskQueue(); } @After @@ -41,58 +42,94 @@ public void shutdown() { terminate(threadPool); } - public void testWarn_LogsOnlyOnce() { - var logger = mock(Logger.class); + public void testExecute_LogsOnlyOnce() { + var mockedLogger = mockLogger(); + + try (var throttler = new ThrottlerManager(Settings.EMPTY, taskQueue.getThreadPool())) { + throttler.init(mockClusterServiceEmpty()); - try (var throttler = new ThrottlerManager(Settings.EMPTY, threadPool, mockClusterServiceEmpty())) { - throttler.warn(logger, "test", new IllegalArgumentException("failed")); + throttler.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verify(1, "test"); + mockedLogger.verifyThrowable(1); - verify(logger, times(1)).warn(eq("test"), any(Throwable.class)); + mockedLogger.clearInvocations(); - throttler.warn(logger, "test", new IllegalArgumentException("failed")); - verifyNoMoreInteractions(logger); + throttler.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verifyNever(); + mockedLogger.verifyNoMoreInteractions(); } } - public void testWarn_AllowsDifferentMessagesToBeLogged() { - var logger = mock(Logger.class); + public void testExecute_AllowsDifferentMessagesToBeLogged() { + var mockedLogger = mockLogger(); + + try (var throttler = new ThrottlerManager(Settings.EMPTY, threadPool)) { + throttler.init(mockClusterServiceEmpty()); + + throttler.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verify(1, "test"); + mockedLogger.verifyThrowable(1); - try (var throttler = new ThrottlerManager(Settings.EMPTY, threadPool, mockClusterServiceEmpty())) { - throttler.warn(logger, "test", new IllegalArgumentException("failed")); - verify(logger, times(1)).warn(eq("test"), any(Throwable.class)); + mockedLogger.clearInvocations(); - throttler.warn(logger, "a different message", new IllegalArgumentException("failed")); - verify(logger, times(1)).warn(eq("a different message"), any(Throwable.class)); + throttler.warn(mockedLogger.logger(), "a different message", new IllegalArgumentException("failed")); + mockedLogger.verify(1, "a different message"); + mockedLogger.verifyThrowable(1); + mockedLogger.verifyNoMoreInteractions(); } } - public void testStartsNewThrottler_WhenResetIntervalIsChanged() { + public void testStartsNewThrottler_WhenLoggingIntervalIsChanged() { var mockThreadPool = mock(ThreadPool.class); when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); - try (var manager = new ThrottlerManager(Settings.EMPTY, mockThreadPool, mockClusterServiceEmpty())) { - var resetInterval = TimeValue.timeValueSeconds(1); + try (var manager = new ThrottlerManager(Settings.EMPTY, mockThreadPool)) { + manager.init(mockClusterServiceEmpty()); + verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), eq(TimeValue.timeValueHours(1)), any()); + + clearInvocations(mockThreadPool); + + var loggingInterval = TimeValue.timeValueSeconds(1); var currentThrottler = manager.getThrottler(); - manager.setResetInterval(resetInterval); - // once for when the throttler is created initially - verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), eq(TimeValue.timeValueDays(1)), any()); - verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), eq(resetInterval), any()); + manager.setLogInterval(loggingInterval); + verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), eq(TimeValue.timeValueSeconds(1)), any()); assertNotSame(currentThrottler, manager.getThrottler()); } } - public void testDoesNotStartNewThrottler_WhenWaitDurationIsChanged() { - var mockThreadPool = mock(ThreadPool.class); - when(mockThreadPool.scheduleWithFixedDelay(any(Runnable.class), any(), any())).thenReturn(mock(Scheduler.Cancellable.class)); + public void testStartsNewThrottler_WhenLoggingIntervalIsChanged_ThreadEmitsPreviousObjectsMessages() { + var mockedLogger = mockLogger(); - try (var manager = new ThrottlerManager(Settings.EMPTY, mockThreadPool, mockClusterServiceEmpty())) { + try (var manager = new ThrottlerManager(Settings.EMPTY, taskQueue.getThreadPool())) { + manager.init(mockClusterServiceEmpty()); + + // first log message should be automatically emitted + manager.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verify(1, "test"); + mockedLogger.verifyThrowable(1); + + mockedLogger.clearInvocations(); + + // This should not be emitted but should increment the counter to 1 + manager.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verifyNever(); + + var loggingInterval = TimeValue.timeValueSeconds(1); var currentThrottler = manager.getThrottler(); + manager.setLogInterval(loggingInterval); + assertNotSame(currentThrottler, manager.getThrottler()); + + mockedLogger.clearInvocations(); + + // This should not be emitted but should increment the counter to 2 + manager.warn(mockedLogger.logger(), "test", new IllegalArgumentException("failed")); + mockedLogger.verifyNever(); + + mockedLogger.clearInvocations(); - var waitDuration = TimeValue.timeValueSeconds(1); - manager.setWaitDuration(waitDuration); - // should only call when initializing the throttler - verify(mockThreadPool, times(1)).scheduleWithFixedDelay(any(Runnable.class), eq(TimeValue.timeValueDays(1)), any()); - assertSame(currentThrottler, manager.getThrottler()); + taskQueue.advanceTime(); + taskQueue.runAllRunnableTasks(); + mockedLogger.verifyContains(1, "test, repeated 2 times"); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java index 77f099557629f..8121d1c31aaa2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/logging/ThrottlerTests.java @@ -7,223 +7,290 @@ package org.elasticsearch.xpack.inference.logging; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogBuilder; import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.After; import org.junit.Before; +import org.mockito.Mockito; import java.time.Clock; -import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.contains; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class ThrottlerTests extends ESTestCase { - - private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); - - private ThreadPool threadPool; + private DeterministicTaskQueue taskQueue; @Before public void init() { - threadPool = createThreadPool(inferenceUtilityPool()); - } - - @After - public void shutdown() { - terminate(threadPool); + taskQueue = new DeterministicTaskQueue(); } - public void testWarn_LogsOnlyOnce() { - var logger = mock(Logger.class); + public void testExecute_LogsOnlyOnce() { + var mockedLogger = mockLogger(); try ( var throttler = new Throttler( TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), Clock.fixed(Instant.now(), ZoneId.systemDefault()), - threadPool, + taskQueue.getThreadPool(), new ConcurrentHashMap<>() ) ) { - throttler.execute("test", logger::warn); + throttler.init(); + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); + + mockedLogger.clearInvocations(); - verify(logger, times(1)).warn(eq("test")); + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNever(); - throttler.execute("test", logger::warn); - verifyNoMoreInteractions(logger); + mockedLogger.verifyNoMoreInteractions(); } } - public void testWarn_LogsOnce_ThenOnceAfterDuration() { - var logger = mock(Logger.class); - - var now = Clock.systemUTC().instant(); - - var clock = mock(Clock.class); + public void testExecute_LogsOnce_ThenOnceWhenEmittingThreadRuns() { + var mockedLogger = mockLogger(); try ( var throttler = new Throttler( TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), - clock, - threadPool, + Clock.fixed(Instant.now(), ZoneId.systemDefault()), + taskQueue.getThreadPool(), new ConcurrentHashMap<>() ) ) { - when(clock.instant()).thenReturn(now); + throttler.init(); // The first call is always logged - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(1)).warn(eq("test"), any(Throwable.class)); - - when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(1))); - // This call should be allowed because the clock thinks it's after the duration period - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(2)).warn(eq("test"), any(Throwable.class)); - - when(clock.instant()).thenReturn(now); - // This call should not be allowed because the clock doesn't think it's pasted the wait period - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verifyNoMoreInteractions(logger); + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); + + mockedLogger.clearInvocations(); + + // This should increment the skipped log count but not actually log anything + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNever(); + + mockedLogger.clearInvocations(); + + // This should log a message with the skip count as 1 + taskQueue.advanceTime(); + taskQueue.runAllRunnableTasks(); + mockedLogger.verifyContains(1, "test, repeated 1 time"); + + mockedLogger.verifyNoMoreInteractions(); } } - public void testWarn_AllowsDifferentMessagesToBeLogged() { - var logger = mock(Logger.class); - - var clock = mock(Clock.class); + public void testExecute_LogsOnce_ThenOnceWhenEmittingThreadRuns_WithException() { + var mockedLogger = mockLogger(); try ( var throttler = new Throttler( TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), - clock, - threadPool, + Clock.fixed(Instant.now(), ZoneId.systemDefault()), + taskQueue.getThreadPool(), new ConcurrentHashMap<>() ) ) { - throttler.execute("test", logger::warn); - verify(logger, times(1)).warn(eq("test")); + throttler.init(); - throttler.execute("a different message", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(1)).warn(eq("a different message"), any(Throwable.class)); - } - } + // The first call is always logged + throttler.execute(mockedLogger.logger, Level.WARN, "test", new IllegalArgumentException("failed")); + mockedLogger.verify(1, "test"); + mockedLogger.verifyThrowable(1); - public void testWarn_LogsRepeated1Time() { - var logger = mock(Logger.class); + mockedLogger.clearInvocations(); - var now = Clock.systemUTC().instant(); + // This should increment the skipped log count but not actually log anything + throttler.execute(mockedLogger.logger, Level.WARN, "test", new IllegalArgumentException("failed")); + mockedLogger.verifyNever(); - var clock = mock(Clock.class); + mockedLogger.clearInvocations(); + + // This should log a message with the skip count as 1 + taskQueue.advanceTime(); + taskQueue.runAllRunnableTasks(); + mockedLogger.verifyContains(1, "test, repeated 1 time"); + mockedLogger.verifyThrowable(1); + + mockedLogger.verifyNoMoreInteractions(); + } + } + + public void testExecute_LogsOnce_ThenOnceWhenEmittingThreadRuns_ThenOnceForFirstLog() { + var mockedLogger = mockLogger(); try ( var throttler = new Throttler( TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), - clock, - threadPool, + Clock.fixed(Instant.now(), ZoneId.systemDefault()), + taskQueue.getThreadPool(), new ConcurrentHashMap<>() ) ) { - when(clock.instant()).thenReturn(now); - // first message is allowed - throttler.execute("test", logger::warn); - verify(logger, times(1)).warn(eq("test")); - - when(clock.instant()).thenReturn(now); // don't allow this message because duration hasn't expired - throttler.execute("test", logger::warn); - verify(logger, times(1)).warn(eq("test")); - - when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(1))); // allow this message by faking expired duration - throttler.execute("test", logger::warn); - verify(logger, times(1)).warn(eq("test, repeated 1 time")); - } - } + throttler.init(); - public void testWarn_LogsRepeated2Times() { - var logger = mock(Logger.class); + // The first call is always logged + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); - var now = Clock.systemUTC().instant(); + mockedLogger.clearInvocations(); - var clock = mock(Clock.class); + // This should increment the skipped log count but not actually log anything + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNever(); + + mockedLogger.clearInvocations(); + + // This should log a message with the skip count as 1 + taskQueue.advanceTime(); + taskQueue.runAllRunnableTasks(); + mockedLogger.verifyContains(1, "test, repeated 1 time"); + + mockedLogger.clearInvocations(); + + // Since the thread ran in the code above it will have reset the state so this will be treated as a first message + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); + + mockedLogger.verifyNoMoreInteractions(); + } + } + + public void testExecute_AllowsDifferentMessagesToBeLogged() { + var mockedLogger = mockLogger(); try ( var throttler = new Throttler( TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), - clock, - threadPool, + Clock.fixed(Instant.now(), ZoneId.systemDefault()), + taskQueue.getThreadPool(), new ConcurrentHashMap<>() ) ) { - when(clock.instant()).thenReturn(now); - // message allowed because it is the first one - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(1)).warn(eq("test"), any(Throwable.class)); - - when(clock.instant()).thenReturn(now); // don't allow these messages because duration hasn't expired - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(1)).warn(eq("test"), any(Throwable.class)); - - when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(1))); // allow this message by faking the duration completion - throttler.execute("test", (message) -> logger.warn(message, new IllegalArgumentException("failed"))); - verify(logger, times(1)).warn(eq("test, repeated 2 times"), any(Throwable.class)); + throttler.init(); + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); + + mockedLogger.clearInvocations(); + + throttler.execute(mockedLogger.logger, Level.WARN, "a different message"); + mockedLogger.verify(1, "a different message"); } } - public void testResetTask_ClearsInternalsAfterInterval() throws InterruptedException { - var calledClearLatch = new CountDownLatch(1); + public void testExecute_LogsRepeated2Times() { + var mockedLogger = mockLogger(); - var now = Clock.systemUTC().instant(); + try ( + var throttler = new Throttler( + TimeValue.timeValueDays(1), + Clock.fixed(Instant.now(), ZoneId.systemDefault()), + taskQueue.getThreadPool(), + new ConcurrentHashMap<>() + ) + ) { + throttler.init(); - var clock = mock(Clock.class); - when(clock.instant()).thenReturn(now); + // The first call is always logged + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verify(1, "test"); + + // This should increment the skipped log count but not actually log anything + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNoMoreInteractions(); - var concurrentMap = mock(ConcurrentHashMap.class); - doAnswer(invocation -> { - calledClearLatch.countDown(); + // This should increment the skipped log count but not actually log anything + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNoMoreInteractions(); - return Void.TYPE; - }).when(concurrentMap).clear(); + mockedLogger.clearInvocations(); - try (@SuppressWarnings("unchecked") - var ignored = new Throttler(TimeValue.timeValueNanos(1), TimeValue.timeValueSeconds(10), clock, threadPool, concurrentMap)) { - calledClearLatch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + // This should log a message with the skip count as 2 + taskQueue.advanceTime(); + taskQueue.runAllRunnableTasks(); + mockedLogger.verifyContains(1, "test, repeated 2 time"); + + mockedLogger.verifyNoMoreInteractions(); } } public void testClose_DoesNotAllowLoggingAnyMore() { - var logger = mock(Logger.class); + var mockedLogger = mockLogger(); var clock = mock(Clock.class); - var throttler = new Throttler( - TimeValue.timeValueDays(1), - TimeValue.timeValueSeconds(10), - clock, - threadPool, - new ConcurrentHashMap<>() - ); + var throttler = new Throttler(TimeValue.timeValueDays(1), clock, taskQueue.getThreadPool(), new ConcurrentHashMap<>()); throttler.close(); - throttler.execute("test", logger::warn); - verifyNoMoreInteractions(logger); + throttler.execute(mockedLogger.logger, Level.WARN, "test"); + mockedLogger.verifyNoMoreInteractions(); + } + + record MockLogger(Logger logger, LogBuilder logBuilder) { + MockLogger clearInvocations() { + Mockito.clearInvocations(logger); + Mockito.clearInvocations(logBuilder); + + return this; + } + + MockLogger verifyNoMoreInteractions() { + Mockito.verifyNoMoreInteractions(logger); + Mockito.verifyNoMoreInteractions(logBuilder); + + return this; + } + + MockLogger verify(int times, String message) { + Mockito.verify(logger, times(times)).atLevel(eq(Level.WARN)); + Mockito.verify(logBuilder, times(times)).log(eq(message)); + + return this; + } + + MockLogger verifyContains(int times, String message) { + Mockito.verify(logger, times(times)).atLevel(eq(Level.WARN)); + Mockito.verify(logBuilder, times(times)).log(contains(message)); + + return this; + } + + MockLogger verifyNever() { + Mockito.verify(logger, never()).atLevel(eq(Level.WARN)); + Mockito.verify(logBuilder, never()).log(any(String.class)); + Mockito.verify(logBuilder, never()).log(any(Throwable.class)); + + return this; + } + + MockLogger verifyThrowable(int times) { + Mockito.verify(logBuilder, times(times)).withThrowable(any(Throwable.class)); + + return this; + } + } + + static MockLogger mockLogger() { + var builder = mock(LogBuilder.class); + when(builder.withThrowable(any(Throwable.class))).thenReturn(builder); + var logger = mock(Logger.class); + when(logger.atLevel(any(Level.class))).thenReturn(builder); + + return new MockLogger(logger, builder); } }