Skip to content

Commit ff8ee85

Browse files
Reset iast request context on root span published (#7969)
1 parent 4dfa404 commit ff8ee85

File tree

11 files changed

+124
-47
lines changed

11 files changed

+124
-47
lines changed

dd-java-agent/agent-iast/src/main/java/com/datadog/iast/IastGlobalContext.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.datadog.iast.taint.TaintedMap;
44
import com.datadog.iast.taint.TaintedObjects;
55
import datadog.trace.api.iast.IastContext;
6+
import java.io.IOException;
67
import java.util.concurrent.TimeUnit;
78
import javax.annotation.Nonnull;
89
import javax.annotation.Nullable;
@@ -22,6 +23,9 @@ public TaintedObjects getTaintedObjects() {
2223
return taintedObjects;
2324
}
2425

26+
@Override
27+
public void close() throws IOException {}
28+
2529
public static class Provider extends IastContext.Provider {
2630

2731
// (16384 * 4) buckets: approx 256K

dd-java-agent/agent-iast/src/main/java/com/datadog/iast/IastOptOutContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,24 @@
22

33
import com.datadog.iast.taint.TaintedObjects;
44
import datadog.trace.api.iast.IastContext;
5+
import java.io.IOException;
56
import javax.annotation.Nonnull;
67
import javax.annotation.Nullable;
78
import org.jetbrains.annotations.NotNull;
89

910
public class IastOptOutContext implements IastContext {
1011

12+
@Nonnull
1113
@SuppressWarnings("unchecked")
1214
@NotNull
1315
@Override
1416
public TaintedObjects getTaintedObjects() {
1517
return TaintedObjects.NoOp.INSTANCE;
1618
}
1719

20+
@Override
21+
public void close() throws IOException {}
22+
1823
public static class Provider extends IastContext.Provider {
1924

2025
final IastContext optOutContext = new IastOptOutContext();

dd-java-agent/agent-iast/src/main/java/com/datadog/iast/IastRequestContext.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
import datadog.trace.api.iast.telemetry.IastMetricCollector.HasMetricCollector;
1616
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
1717
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
18+
import java.io.IOException;
1819
import java.util.Queue;
1920
import java.util.concurrent.ArrayBlockingQueue;
21+
import java.util.function.Consumer;
2022
import javax.annotation.Nonnull;
2123
import javax.annotation.Nullable;
2224

@@ -27,6 +29,7 @@ public class IastRequestContext implements IastContext, HasMetricCollector {
2729
private final VulnerabilityBatch vulnerabilityBatch;
2830
private final OverheadContext overheadContext;
2931
private TaintedObjects taintedObjects;
32+
@Nullable private Consumer<IastContext> release;
3033
@Nullable private IastMetricCollector collector;
3134
@Nullable private volatile String strictTransportSecurity;
3235
@Nullable private volatile String xContentTypeOptions;
@@ -124,12 +127,20 @@ public void setTaintedObjects(@Nonnull final TaintedObjects taintedObjects) {
124127
this.taintedObjects = taintedObjects;
125128
}
126129

130+
@Override
131+
public void close() throws IOException {
132+
if (release != null) {
133+
release.accept(this);
134+
release = null;
135+
}
136+
}
137+
127138
public static class Provider extends IastContext.Provider {
128139

129140
// 16384 buckets: approx 64K
130141
static final int MAP_SIZE = TaintedMap.DEFAULT_CAPACITY;
131142

132-
final Queue<TaintedObjects> pool =
143+
private final Queue<TaintedObjects> pool =
133144
new ArrayBlockingQueue<>(
134145
Math.max(
135146
Config.get().getIastMaxConcurrentRequests(), DEFAULT_IAST_MAX_CONCURRENT_REQUESTS));
@@ -154,19 +165,28 @@ public IastContext buildRequestContext() {
154165
if (taintedObjects == null) {
155166
taintedObjects = TaintedObjects.build(TaintedMap.build(MAP_SIZE));
156167
}
157-
return new IastRequestContext(taintedObjects);
168+
final IastRequestContext ctx = new IastRequestContext(taintedObjects);
169+
ctx.release = this::releaseRequestContext;
170+
return ctx;
158171
}
159172

160173
@SuppressWarnings("unchecked")
161174
@Override
162175
public void releaseRequestContext(@Nonnull final IastContext context) {
163-
final TaintedObjects taintedObjects = context.getTaintedObjects();
176+
final IastRequestContext iastCtx = (IastRequestContext) context;
177+
178+
// reset tainted objects map
179+
final TaintedObjects taintedObjects = iastCtx.getTaintedObjects();
164180
taintedObjects.clear();
165-
// add the root instance to the pool
166-
if (taintedObjects instanceof Wrapper) {
167-
pool.offer(((Wrapper<TaintedObjects>) taintedObjects).unwrap());
168-
} else {
169-
pool.offer(taintedObjects);
181+
182+
// return to pool and update internal ref
183+
final TaintedObjects unwrapped =
184+
taintedObjects instanceof Wrapper
185+
? ((Wrapper<TaintedObjects>) taintedObjects).unwrap()
186+
: taintedObjects;
187+
if (unwrapped != TaintedObjects.NoOp.INSTANCE) {
188+
pool.offer(unwrapped);
189+
iastCtx.setTaintedObjects(TaintedObjects.NoOp.INSTANCE);
170190
}
171191
}
172192
}

dd-java-agent/agent-iast/src/test/groovy/com/datadog/iast/IastRequestContextTest.groovy

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.datadog.iast
22

33
import com.datadog.iast.model.Range
44
import com.datadog.iast.taint.TaintedObjects
5+
import datadog.trace.api.Config
56
import datadog.trace.api.gateway.RequestContext
67
import datadog.trace.api.gateway.RequestContextSlot
78
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
@@ -53,6 +54,15 @@ class IastRequestContextTest extends DDSpecification {
5354

5455
then:
5556
1 * tracer.activeSpan() >> span
57+
1 * span.getRequestContext() >> null
58+
resolved == null
59+
60+
when:
61+
resolved = provider.resolve()
62+
63+
then:
64+
1 * tracer.activeSpan() >> span
65+
1 * span.getRequestContext() >> reqCtx
5666
resolved === initialCtx
5767
}
5868

@@ -72,5 +82,42 @@ class IastRequestContextTest extends DDSpecification {
7282
then:
7383
to.count() == 0
7484
provider.pool.size() == 1
85+
86+
when:
87+
final maxPoolSize = Config.get().getIastMaxConcurrentRequests()
88+
final list = (1..2 * maxPoolSize).collect {
89+
provider.buildRequestContext()
90+
}
91+
92+
then:
93+
provider.pool.size() == 0
94+
95+
when:
96+
list.each { provider.releaseRequestContext(it) }
97+
98+
then:
99+
provider.pool.size() == maxPoolSize
100+
}
101+
102+
void 'ensure that the context releases all tainted objects on close'() {
103+
setup:
104+
final ctx = provider.buildRequestContext() as IastRequestContext
105+
106+
when:
107+
ctx.withCloseable {
108+
it.taintedObjects.taint(UUID.randomUUID(), [] as Range[])
109+
}
110+
111+
then:
112+
ctx.taintedObjects.count() == 0
113+
114+
when:
115+
ctx.withCloseable {
116+
it.taintedObjects.taint(UUID.randomUUID(), [] as Range[])
117+
assert it.taintedObjects.count() == 0
118+
}
119+
120+
then:
121+
ctx.taintedObjects.count() == 0
75122
}
76123
}

dd-java-agent/agent-iast/src/test/groovy/com/datadog/iast/IastSystemTest.groovy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class IastSystemTest extends DDSpecification {
7373

7474
then:
7575
1 * iastContext.getTaintedObjects()
76+
1 * iastContext.setTaintedObjects(_)
7677
1 * iastContext.getMetricCollector()
7778
1 * traceSegment.setTagTop('_dd.iast.enabled', 1)
7879
1 * iastContext.getxContentTypeOptions() >> 'nosniff'

dd-java-agent/agent-iast/src/testFixtures/groovy/com/datadog/iast/test/IastAgentTestRunner.groovy

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import datadog.trace.agent.test.AgentTestRunner
66
import datadog.trace.agent.tooling.bytebuddy.iast.TaintableVisitor
77
import datadog.trace.api.gateway.CallbackProvider
88
import datadog.trace.api.gateway.Events
9-
import datadog.trace.api.gateway.Flow
109
import datadog.trace.api.gateway.RequestContextSlot
1110
import datadog.trace.api.iast.IastContext
1211
import datadog.trace.api.iast.SourceTypes
@@ -15,7 +14,6 @@ import datadog.trace.bootstrap.instrumentation.api.AgentTracer
1514
import datadog.trace.bootstrap.instrumentation.api.TagContext
1615
import datadog.trace.core.DDSpan
1716

18-
import java.util.function.Supplier
1917

2018
class IastAgentTestRunner extends AgentTestRunner implements IastRequestContextPreparationTrait {
2119
public static final EMPTY_SOURCE = new Source(SourceTypes.NONE, '', '')
@@ -40,25 +38,18 @@ class IastAgentTestRunner extends AgentTestRunner implements IastRequestContextP
4038
IastContext.Provider.get().taintedObjects
4139
}
4240

43-
protected TaintedObjectCollection getLocalTaintedObjectCollection() {
44-
new TaintedObjectCollection(localTaintedObjects)
45-
}
46-
47-
protected TaintedObjectCollection getTaintedObjectCollection(DDSpan span) {
48-
final IastContext ctx = span.getRequestContext().getData(RequestContextSlot.IAST)
49-
return new TaintedObjectCollection(ctx.getTaintedObjects())
50-
}
51-
5241
protected DDSpan runUnderIastTrace(Closure cl) {
5342
CallbackProvider iastCbp = TEST_TRACER.getCallbackProvider(RequestContextSlot.IAST)
54-
Supplier<Flow<Object>> reqStartCb = iastCbp.getCallback(Events.EVENTS.requestStarted())
43+
def reqStartCb = iastCbp.getCallback(Events.EVENTS.requestStarted())
44+
def reqEndCb = iastCbp.getCallback(Events.EVENTS.requestEnded())
5545

5646
def iastCtx = reqStartCb.get().result
5747
def ddctx = new TagContext().withRequestContextDataIast(iastCtx)
5848
AgentSpan span = TEST_TRACER.startSpan("test", "test-iast-span", ddctx)
5949
try {
6050
AgentTracer.activateSpan(span).withCloseable cl
6151
} finally {
52+
reqEndCb.apply(span.requestContext, span)
6253
span.finish()
6354
}
6455

dd-java-agent/agent-iast/src/testFixtures/groovy/com/datadog/iast/test/IastRequestContextPreparationTrait.groovy

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,27 @@ import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.get
1919
trait IastRequestContextPreparationTrait {
2020

2121
static void iastSystemSetup(Closure reqEndAction = null) {
22-
def ss = AgentTracer.get().getSubscriptionService(RequestContextSlot.IAST)
22+
final tracer = AgentTracer.get()
23+
def ss = tracer.getSubscriptionService(RequestContextSlot.IAST)
2324
IastSystem.start(ss, new NoopOverheadController())
2425

2526
EventType<Supplier<Flow<Object>>> requestStarted = Events.get().requestStarted()
2627
EventType<BiFunction<RequestContext, IGSpanInfo, Flow<Void>>> requestEnded =
2728
Events.get().requestEnded()
2829

2930
// get original callbacks
30-
CallbackProvider provider = AgentTracer.get().getCallbackProvider(RequestContextSlot.IAST)
31+
CallbackProvider provider = tracer.getCallbackProvider(RequestContextSlot.IAST)
3132
def origRequestStarted = provider.getCallback(requestStarted)
3233
def origRequestEnded = provider.getCallback(requestEnded)
3334

3435
// wrap the original IG callbacks
3536
ss.reset()
3637
ss.registerCallback(requestStarted, new TaintedMapSaveStrongRefsRequestStarted(orig: origRequestStarted))
37-
if (reqEndAction != null) {
38-
ss.registerCallback(requestEnded, new TaintedMapSavingRequestEnded(
39-
original: origRequestEnded, beforeAction: reqEndAction))
40-
}
38+
ss.registerCallback(
39+
requestEnded,
40+
reqEndAction == null
41+
? origRequestEnded
42+
: new TaintedMapSavingRequestEnded(original: origRequestEnded, beforeAction: reqEndAction))
4143
}
4244

4345
static void iastSystemCleanup() {

dd-java-agent/instrumentation/kafka-clients-0.11/src/iastLatestDepTest3/groovy/iast/KafkaIastDeserializerTest.groovy

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package iast
22

33
import com.datadog.iast.propagation.PropagationModuleImpl
4-
import com.datadog.iast.test.IastAgentTestRunner
4+
import com.datadog.iast.test.IastRequestTestRunner
55
import datadog.trace.api.iast.InstrumentationBridge
66
import datadog.trace.api.iast.SourceTypes
77
import org.apache.kafka.common.header.internals.RecordHeaders
@@ -16,7 +16,7 @@ import java.nio.ByteBuffer
1616
import static org.hamcrest.CoreMatchers.instanceOf
1717
import static org.hamcrest.core.IsEqual.equalTo
1818

19-
class KafkaIastDeserializerTest extends IastAgentTestRunner {
19+
class KafkaIastDeserializerTest extends IastRequestTestRunner {
2020

2121
private static final int BUFF_OFFSET = 10
2222

@@ -31,13 +31,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
3131
final deserializer = new StringDeserializer()
3232

3333
when:
34-
final span = runUnderIastTrace {
34+
runUnderIastTrace {
3535
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
3636
test.method.deserialize(deserializer, "test", payload)
3737
}
3838

3939
then:
40-
final to = getTaintedObjectCollection(span)
40+
final to = finReqTaintedObjects
4141
to.hasTaintedObject {
4242
value('Hello World!')
4343
range(0, 12, source(origin))
@@ -58,13 +58,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
5858
final deserializer = new ByteArrayDeserializer()
5959

6060
when:
61-
final span = runUnderIastTrace {
61+
runUnderIastTrace {
6262
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
6363
test.method.deserialize(deserializer, "test", payload)
6464
}
6565

6666
then:
67-
final to = getTaintedObjectCollection(span)
67+
final to = finReqTaintedObjects
6868
to.hasTaintedObject {
6969
value(equalTo(payload))
7070
range(0, Integer.MAX_VALUE, source(origin))
@@ -85,13 +85,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
8585
final deserializer = new ByteBufferDeserializer()
8686

8787
when:
88-
final span = runUnderIastTrace {
88+
runUnderIastTrace {
8989
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
9090
test.method.deserialize(deserializer, "test", payload)
9191
}
9292

9393
then:
94-
final to = getTaintedObjectCollection(span)
94+
final to = finReqTaintedObjects
9595
to.hasTaintedObject {
9696
value(instanceOf(ByteBuffer))
9797
range(0, Integer.MAX_VALUE, source(origin))
@@ -113,13 +113,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
113113
final deserializer = new JsonDeserializer(TestBean)
114114

115115
when:
116-
final span = runUnderIastTrace {
116+
runUnderIastTrace {
117117
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
118118
test.method.deserialize(deserializer, 'test', payload)
119119
}
120120

121121
then:
122-
final to = getTaintedObjectCollection(span)
122+
final to = finReqTaintedObjects
123123
to.hasTaintedObject {
124124
value(instanceOf(TestBean))
125125
range(0, Integer.MAX_VALUE, source(origin as byte))

0 commit comments

Comments
 (0)