Skip to content

Commit 9281ceb

Browse files
committed
Make cold start tracking thread-safe.
1 parent d781483 commit 9281ceb

File tree

6 files changed

+40
-22
lines changed

6 files changed

+40
-22
lines changed

powertools-logging/powertools-logging-log4j/src/test/java/org/apache/logging/log4j/layout/template/json/resolver/PowerToolsResolverFactoryTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.nio.file.NoSuchFileException;
2525
import java.nio.file.Paths;
2626
import java.nio.file.StandardOpenOption;
27+
import java.util.concurrent.atomic.AtomicBoolean;
2728

2829
import org.junit.jupiter.api.AfterEach;
2930
import org.junit.jupiter.api.BeforeEach;
@@ -48,7 +49,7 @@ void setUp() throws IllegalAccessException, IOException {
4849
MDC.clear();
4950
// Reset cold start state
5051
writeStaticField(LambdaHandlerProcessor.class, "isColdStart", null, true);
51-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
52+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
5253

5354
context = new TestLambdaContext();
5455
// Make sure file is cleaned up before running tests

powertools-logging/powertools-logging-logback/src/test/java/software/amazon/lambda/powertools/logging/internal/LambdaEcsEncoderTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.nio.file.NoSuchFileException;
2626
import java.nio.file.Paths;
2727
import java.nio.file.StandardOpenOption;
28+
import java.util.concurrent.atomic.AtomicBoolean;
2829

2930
import org.junit.jupiter.api.AfterEach;
3031
import org.junit.jupiter.api.BeforeEach;
@@ -57,7 +58,7 @@ void setUp() throws IllegalAccessException, IOException {
5758
MDC.clear();
5859
// Reset cold start state
5960
writeStaticField(LambdaHandlerProcessor.class, "isColdStart", null, true);
60-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
61+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
6162

6263
context = new TestLambdaContext();
6364
// Make sure file is cleaned up before running tests

powertools-logging/powertools-logging-logback/src/test/java/software/amazon/lambda/powertools/logging/internal/LambdaJsonEncoderTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
import java.util.Collections;
4444
import java.util.Date;
4545
import java.util.TimeZone;
46+
import java.util.concurrent.atomic.AtomicBoolean;
47+
4648
import org.junit.jupiter.api.AfterEach;
4749
import org.junit.jupiter.api.Assertions;
4850
import org.junit.jupiter.api.BeforeEach;
@@ -76,7 +78,7 @@ void setUp() throws IllegalAccessException, IOException {
7678
MDC.clear();
7779
// Reset cold start state
7880
writeStaticField(LambdaHandlerProcessor.class, "isColdStart", null, true);
79-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
81+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
8082

8183
context = new TestLambdaContext();
8284
// Make sure file is cleaned up before running tests

powertools-logging/src/main/java/software/amazon/lambda/powertools/logging/PowertoolsLogging.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Arrays;
3030
import java.util.Locale;
3131
import java.util.Random;
32+
import java.util.concurrent.atomic.AtomicBoolean;
3233

3334
import org.slf4j.Logger;
3435
import org.slf4j.LoggerFactory;
@@ -65,7 +66,7 @@
6566
public final class PowertoolsLogging {
6667
private static final Logger LOG = LoggerFactory.getLogger(PowertoolsLogging.class);
6768
private static final ThreadLocal<Random> SAMPLER = ThreadLocal.withInitial(Random::new);
68-
private static volatile boolean hasBeenInitialized = false;
69+
private static AtomicBoolean hasBeenInitialized = new AtomicBoolean(false);
6970

7071
static {
7172
initializeLogLevel();
@@ -176,16 +177,14 @@ public static void initializeLogging(Context context, String correlationIdPath,
176177
* configures sampling rate for DEBUG logging, and optionally extracts
177178
* correlation ID from the event.
178179
*
180+
* This method is tread-safe.
181+
*
179182
* @param context the Lambda context provided by AWS Lambda runtime
180183
* @param samplingRate sampling rate for DEBUG logging (0.0 to 1.0)
181184
* @param correlationIdPath JSON path to extract correlation ID from event (can be null)
182185
* @param event the Lambda event object (required if correlationIdPath is provided)
183186
*/
184187
public static void initializeLogging(Context context, double samplingRate, String correlationIdPath, Object event) {
185-
if (hasBeenInitialized) {
186-
coldStartDone();
187-
}
188-
hasBeenInitialized = true;
189188

190189
addLambdaContextToLoggingContext(context);
191190
setLogLevelBasedOnSamplingRate(samplingRate);
@@ -196,12 +195,17 @@ public static void initializeLogging(Context context, double samplingRate, Strin
196195
}
197196
}
198197

199-
private static void addLambdaContextToLoggingContext(Context context) {
198+
// Synchronized since isColdStart() is a globally managed constant in LambdaHandlerProcessor
199+
private static synchronized void addLambdaContextToLoggingContext(Context context) {
200200
if (context != null) {
201201
PowertoolsLoggedFields.setValuesFromLambdaContext(context).forEach(MDC::put);
202-
MDC.put(FUNCTION_COLD_START.getName(), isColdStart() ? "true" : "false");
203-
MDC.put(SERVICE.getName(), serviceName());
204202
}
203+
204+
MDC.put(FUNCTION_COLD_START.getName(), isColdStart() ? "true" : "false");
205+
if (hasBeenInitialized.compareAndSet(false, true)) {
206+
coldStartDone();
207+
}
208+
MDC.put(SERVICE.getName(), serviceName());
205209
}
206210

207211
private static void setLogLevelBasedOnSamplingRate(double samplingRate) {

powertools-logging/src/test/java/software/amazon/lambda/powertools/logging/PowertoolsLoggingTest.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.nio.file.Paths;
2828
import java.nio.file.StandardOpenOption;
2929
import java.util.Map;
30+
import java.util.concurrent.atomic.AtomicBoolean;
3031

3132
import org.junit.jupiter.api.AfterEach;
3233
import org.junit.jupiter.api.BeforeEach;
@@ -68,7 +69,7 @@ void setUp() throws IllegalAccessException, IOException {
6869

6970
// Reset cold start state
7071
writeStaticField(LambdaHandlerProcessor.class, "isColdStart", null, true);
71-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
72+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
7273

7374
try {
7475
FileChannel.open(Paths.get("target/logfile.json"), StandardOpenOption.WRITE).truncate(0).close();
@@ -313,7 +314,7 @@ void initializeLogging_withEnvVarAndParameter_shouldUseEnvVarPrecedence() throws
313314
@Test
314315
void initializeLogging_calledTwice_shouldMarkColdStartDoneOnSecondCall() throws IllegalAccessException {
315316
// GIVEN
316-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
317+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
317318

318319
// WHEN - First call
319320
PowertoolsLogging.initializeLogging(context);
@@ -371,21 +372,24 @@ void initializeLogging_concurrentCalls_shouldBeThreadSafe() throws InterruptedEx
371372
int threadCount = 10;
372373
Thread[] threads = new Thread[threadCount];
373374
String[] samplingRates = new String[threadCount];
375+
boolean[] coldStarts = new boolean[threadCount];
374376
boolean[] success = new boolean[threadCount];
375377

376378
// WHEN - Multiple threads call initializeLogging with alternating sampling rates
377379
for (int i = 0; i < threadCount; i++) {
378380
final int threadIndex = i;
379381
final double samplingRate = (i % 2 == 0) ? 1.0 : 0.0; // Alternate between 1.0 and 0.0
380-
382+
381383
threads[i] = new Thread(() -> {
382384
try {
383385
PowertoolsLogging.initializeLogging(context, samplingRate);
384-
385-
// Capture the sampling rate set in MDC (thread-local)
386+
387+
// Capture the sampling rate and cold start values set in MDC (thread-local)
386388
samplingRates[threadIndex] = MDC.get(PowertoolsLoggedFields.SAMPLING_RATE.getName());
389+
coldStarts[threadIndex] = Boolean
390+
.parseBoolean(MDC.get(PowertoolsLoggedFields.FUNCTION_COLD_START.getName()));
387391
success[threadIndex] = true;
388-
392+
389393
// Clean up thread-local state
390394
PowertoolsLogging.clearState(true);
391395
} catch (Exception e) {
@@ -408,12 +412,17 @@ void initializeLogging_concurrentCalls_shouldBeThreadSafe() throws InterruptedEx
408412
for (boolean result : success) {
409413
assertThat(result).isTrue();
410414
}
411-
412-
// THEN - Each thread should have its own sampling rate in MDC
415+
416+
// THEN - Each thread should have its own sampling rate in MDC and exactly one invocation was a cold start
417+
int coldStartCount = 0;
413418
for (int i = 0; i < threadCount; i++) {
414-
String expectedRate = (i % 2 == 0) ? "1.0" : "0.0";
415-
assertThat(samplingRates[i]).as("Thread %d should have sampling rate %s", i, expectedRate).isEqualTo(expectedRate);
419+
String expectedSamplingRate = (i % 2 == 0) ? "1.0" : "0.0";
420+
assertThat(samplingRates[i]).as("Thread %d should have sampling rate %s", i, expectedSamplingRate)
421+
.isEqualTo(expectedSamplingRate);
422+
423+
coldStartCount += coldStarts[i] ? 1 : 0;
416424
}
425+
assertThat(coldStartCount).isEqualTo(1);
417426
}
418427

419428
private void reinitializeLogLevel() {

powertools-logging/src/test/java/software/amazon/lambda/powertools/logging/internal/LambdaLoggingAspectTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import java.util.Collections;
4040
import java.util.List;
4141
import java.util.Map;
42+
import java.util.concurrent.atomic.AtomicBoolean;
4243

4344
import org.junit.jupiter.api.AfterEach;
4445
import org.junit.jupiter.api.BeforeEach;
@@ -99,7 +100,7 @@ void setUp() throws IllegalAccessException, IOException {
99100

100101
// Reset cold start state
101102
writeStaticField(LambdaHandlerProcessor.class, "isColdStart", null, true);
102-
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", false, true);
103+
writeStaticField(PowertoolsLogging.class, "hasBeenInitialized", new AtomicBoolean(false), true);
103104

104105
context = new TestLambdaContext();
105106
requestHandler = new PowertoolsLogEnabled();

0 commit comments

Comments
 (0)