Skip to content

Commit 89bde8f

Browse files
committed
Fix
1 parent 949ee8c commit 89bde8f

File tree

2 files changed

+131
-33
lines changed

2 files changed

+131
-33
lines changed

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ public class AppSecRequestContext implements DataBundle, Closeable {
141141
private boolean respDataPublished;
142142
private boolean pathParamsPublished;
143143
private volatile Map<String, Object> derivatives;
144+
private final Object derivativesLock = new Object();
144145

145146
private final AtomicBoolean rateLimited = new AtomicBoolean(false);
146147
private volatile boolean throttled;
@@ -649,9 +650,11 @@ public void close() {
649650
requestHeaders.clear();
650651
responseHeaders.clear();
651652
persistentData.clear();
652-
if (derivatives != null) {
653-
derivatives.clear();
654-
derivatives = null;
653+
synchronized (derivativesLock) {
654+
if (derivatives != null) {
655+
derivatives.clear();
656+
derivatives = null;
657+
}
655658
}
656659
}
657660
}
@@ -743,10 +746,7 @@ public void reportDerivatives(Map<String, Object> data) {
743746
log.debug("Reporting derivatives: {}", data);
744747
if (data == null || data.isEmpty()) return;
745748

746-
// Store raw derivatives
747-
if (derivatives == null) {
748-
derivatives = new HashMap<>();
749-
}
749+
Map<String, Object> newDerivatives = new LinkedHashMap<>();
750750

751751
// Process each attribute according to the specification
752752
for (Map.Entry<String, Object> entry : data.entrySet()) {
@@ -762,7 +762,7 @@ public void reportDerivatives(Map<String, Object> data) {
762762
Object literalValue = config.get("value");
763763
if (literalValue != null) {
764764
// Preserve the original type - don't convert to string
765-
derivatives.put(attributeKey, literalValue);
765+
newDerivatives.put(attributeKey, literalValue);
766766
log.debug(
767767
"Added literal attribute: {} = {} (type: {})",
768768
attributeKey,
@@ -781,16 +781,25 @@ else if (config.containsKey("address")) {
781781
Object extractedValue = extractValueFromRequestData(address, keyPath, transformers);
782782
if (extractedValue != null) {
783783
// For extracted values, convert to string as they come from request data
784-
derivatives.put(attributeKey, extractedValue.toString());
784+
newDerivatives.put(attributeKey, extractedValue.toString());
785785
log.debug("Added extracted attribute: {} = {}", attributeKey, extractedValue);
786786
}
787787
}
788788
} else {
789789
// Handle plain string/numeric values
790-
derivatives.put(attributeKey, attributeConfig);
790+
newDerivatives.put(attributeKey, attributeConfig);
791791
log.debug("Added direct attribute: {} = {}", attributeKey, attributeConfig);
792792
}
793793
}
794+
795+
if (!newDerivatives.isEmpty()) {
796+
synchronized (derivativesLock) {
797+
if (derivatives == null) {
798+
derivatives = new HashMap<>();
799+
}
800+
derivatives.putAll(newDerivatives);
801+
}
802+
}
794803
}
795804

796805
/**
@@ -943,40 +952,48 @@ public boolean commitDerivatives(TraceSegment traceSegment) {
943952
return false;
944953
}
945954

955+
Map<String, Object> derivativesSnapshot;
956+
synchronized (derivativesLock) {
957+
if (derivatives == null || derivatives.isEmpty()) {
958+
derivatives = null;
959+
return true;
960+
}
961+
derivativesSnapshot = new LinkedHashMap<>(derivatives);
962+
derivatives = null;
963+
}
964+
946965
// Process and commit derivatives directly
947-
if (derivatives != null && !derivatives.isEmpty()) {
948-
for (Map.Entry<String, Object> entry : derivatives.entrySet()) {
949-
String key = entry.getKey();
950-
Object value = entry.getValue();
951-
952-
// Handle different value types
953-
if (value instanceof Number) {
954-
traceSegment.setTagTop(key, (Number) value);
955-
} else if (value instanceof String) {
956-
// Try to parse as numeric, otherwise use as string
957-
Number parsedNumber = convertToNumericAttribute((String) value);
958-
if (parsedNumber != null) {
959-
traceSegment.setTagTop(key, parsedNumber);
960-
} else {
961-
traceSegment.setTagTop(key, value);
962-
}
963-
} else if (value instanceof Boolean) {
964-
traceSegment.setTagTop(key, value);
966+
for (Map.Entry<String, Object> entry : derivativesSnapshot.entrySet()) {
967+
String key = entry.getKey();
968+
Object value = entry.getValue();
969+
970+
// Handle different value types
971+
if (value instanceof Number) {
972+
traceSegment.setTagTop(key, (Number) value);
973+
} else if (value instanceof String) {
974+
// Try to parse as numeric, otherwise use as string
975+
Number parsedNumber = convertToNumericAttribute((String) value);
976+
if (parsedNumber != null) {
977+
traceSegment.setTagTop(key, parsedNumber);
965978
} else {
966-
// Convert other types to string
967-
traceSegment.setTagTop(key, value.toString());
979+
traceSegment.setTagTop(key, value);
968980
}
981+
} else if (value instanceof Boolean) {
982+
traceSegment.setTagTop(key, value);
983+
} else {
984+
// Convert other types to string
985+
traceSegment.setTagTop(key, value.toString());
969986
}
970987
}
971988

972-
// Clear all attribute maps
973-
derivatives = null;
974989
return true;
975990
}
976991

977992
// Mainly used for testing and logging
978993
Set<String> getDerivativeKeys() {
979-
return derivatives == null ? emptySet() : new HashSet<>(derivatives.keySet());
994+
synchronized (derivativesLock) {
995+
return derivatives == null ? emptySet() : new HashSet<>(derivatives.keySet());
996+
}
980997
}
981998

982999
public boolean isThrottled(RateLimiter rateLimiter) {

dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/gateway/AppSecRequestContextSpecification.groovy

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@ import com.squareup.moshi.JsonAdapter
1111
import com.squareup.moshi.Moshi
1212
import com.squareup.moshi.Types
1313
import datadog.trace.api.Config
14+
import datadog.trace.api.internal.TraceSegment
1415
import datadog.trace.api.telemetry.LogCollector
1516
import datadog.trace.test.logging.TestLogCollector
1617
import datadog.trace.test.util.DDSpecification
1718
import datadog.trace.util.stacktrace.StackTraceEvent
1819
import datadog.trace.util.stacktrace.StackTraceFrame
20+
import java.util.concurrent.ConcurrentLinkedQueue
21+
import java.util.concurrent.CountDownLatch
22+
import java.util.concurrent.TimeUnit
1923
import okio.Okio
2024

2125
class AppSecRequestContextSpecification extends DDSpecification {
@@ -397,6 +401,83 @@ class AppSecRequestContextSpecification extends DDSpecification {
397401
keys.contains("_dd.appsec.s.res.content_type_upper")
398402
}
399403

404+
void testCommitDerivativesNormalizesAttributeTypes() {
405+
given:
406+
def context = new AppSecRequestContext()
407+
context.reportDerivatives([
408+
numeric: [value: "42"],
409+
string: [value: "value"],
410+
bool: true,
411+
list: [value: [1, 2]]
412+
])
413+
def traceSegment = Mock(TraceSegment)
414+
415+
when:
416+
context.commitDerivatives(traceSegment)
417+
418+
then:
419+
1 * traceSegment.setTagTop("numeric", 42L)
420+
1 * traceSegment.setTagTop("string", "value")
421+
1 * traceSegment.setTagTop("bool", true)
422+
1 * traceSegment.setTagTop("list", "[1, 2]")
423+
0 * _
424+
context.getDerivativeKeys().isEmpty()
425+
}
426+
427+
void testCommitDerivativesHandlesConcurrentUpdates() {
428+
given:
429+
def context = new AppSecRequestContext()
430+
def start = new CountDownLatch(1)
431+
def done = new CountDownLatch(6)
432+
def errors = new ConcurrentLinkedQueue<Throwable>()
433+
def traceSegment = TraceSegment.NoOp.INSTANCE
434+
435+
def reporters = []
436+
for (int n = 0; n < 3; n++) {
437+
reporters.add(Thread.start({
438+
try {
439+
start.await()
440+
for (int idx = 0; idx < 100; idx++) {
441+
def key = "key-${Thread.currentThread().id}-${idx}" as String
442+
def val = [value: idx as String]
443+
context.reportDerivatives([(key): val])
444+
}
445+
} catch (Throwable t) {
446+
errors.add(t)
447+
} finally {
448+
done.countDown()
449+
}
450+
}))
451+
}
452+
453+
def committers = []
454+
for (int n = 0; n < 3; n++) {
455+
committers.add(Thread.start({
456+
try {
457+
start.await()
458+
for (int i = 0; i < 100; i++) {
459+
context.commitDerivatives(traceSegment)
460+
}
461+
} catch (Throwable t) {
462+
errors.add(t)
463+
} finally {
464+
done.countDown()
465+
}
466+
}))
467+
}
468+
469+
when:
470+
start.countDown()
471+
def completed = done.await(10, TimeUnit.SECONDS)
472+
(reporters + committers)*.join()
473+
context.commitDerivatives(traceSegment)
474+
475+
then:
476+
completed
477+
errors.isEmpty()
478+
context.getDerivativeKeys().isEmpty()
479+
}
480+
400481
def "test attribute handling with unknown address"() {
401482
given:
402483
def context = new AppSecRequestContext()

0 commit comments

Comments
 (0)