Skip to content

Commit e92a04f

Browse files
Implemented time-based sampling per endpoint
1 parent 11069e2 commit e92a04f

File tree

7 files changed

+218
-84
lines changed

7 files changed

+218
-84
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package com.datadog.appsec.api.security;
2+
3+
import java.util.Deque;
4+
import java.util.Map;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import java.util.concurrent.ConcurrentLinkedDeque;
7+
8+
/**
9+
* The ApiAccessTracker class provides a mechanism to track API access events, managing them within
10+
* a specified capacity limit. Each event is associated with a unique combination of route, method,
11+
* and status code, which is used to generate a unique key for tracking access timestamps.
12+
*
13+
* <p>Usage: - When an API access event occurs, the `updateApiAccessIfExpired` method is called with
14+
* the route, method, and status code of the API request. - If the access event for the given
15+
* parameters is new or has expired (based on the expirationTimeInMs threshold), the event's
16+
* timestamp is updated, effectively moving the event to the end of the tracking list. - If the
17+
* tracker's capacity is reached, the oldest event is automatically removed to make room for new
18+
* events. - This mechanism ensures that the tracker always contains the most recent access events
19+
* within the specified capacity limit, with older, less relevant events being discarded.
20+
*/
21+
public class ApiAccessTracker {
22+
private static final int INTERVAL_SECONDS = 30;
23+
private static final int MAX_SIZE = 4096;
24+
private final Map<Long, Long> apiAccessMap; // Map<hash, timestamp>
25+
private final Deque<Long> apiAccessQueue; // hashes ordered by access time
26+
private final long expirationTimeInMs;
27+
private final int capacity;
28+
29+
public ApiAccessTracker() {
30+
this(MAX_SIZE, INTERVAL_SECONDS * 1000);
31+
}
32+
33+
public ApiAccessTracker(int capacity, long expirationTimeInMs) {
34+
this.capacity = capacity;
35+
this.expirationTimeInMs = expirationTimeInMs;
36+
this.apiAccessMap = new ConcurrentHashMap<>();
37+
this.apiAccessQueue = new ConcurrentLinkedDeque<>();
38+
}
39+
40+
/**
41+
* Updates the API access log with the given route, method, and status code. If the record exists
42+
* and is outdated, it is updated by moving to the end of the list. If the record does not exist,
43+
* a new record is added. If the capacity limit is reached, the oldest record is removed. Returns
44+
* true if the record was updated or added, false otherwise.
45+
*
46+
* @param route The route of the API endpoint request
47+
* @param method The method of the API request
48+
* @param statusCode The HTTP response status code of the API request
49+
* @return return true if the record was updated or added, false otherwise
50+
*/
51+
public boolean updateApiAccessIfExpired(String route, String method, int statusCode) {
52+
long currentTime = System.currentTimeMillis();
53+
long hash = computeApiHash(route, method, statusCode);
54+
55+
// New or updated record
56+
boolean isNewOrUpdated = false;
57+
if (!apiAccessMap.containsKey(hash)
58+
|| currentTime - apiAccessMap.get(hash) > expirationTimeInMs) {
59+
60+
cleanupExpiredEntries(currentTime);
61+
62+
apiAccessMap.put(hash, currentTime); // Update timestamp
63+
// move hash to the end of the queue
64+
apiAccessQueue.remove(hash);
65+
apiAccessQueue.addLast(hash);
66+
isNewOrUpdated = true;
67+
68+
// Remove the oldest hash if capacity is reached
69+
while (apiAccessQueue.size() > this.capacity) {
70+
Long oldestHash = apiAccessQueue.pollFirst();
71+
if (oldestHash != null) {
72+
apiAccessMap.remove(oldestHash);
73+
}
74+
}
75+
}
76+
77+
return isNewOrUpdated;
78+
}
79+
80+
private void cleanupExpiredEntries(long currentTime) {
81+
while (!apiAccessQueue.isEmpty()) {
82+
Long oldestHash = apiAccessQueue.peekFirst();
83+
if (oldestHash == null) break;
84+
85+
Long lastAccessTime = apiAccessMap.get(oldestHash);
86+
if (lastAccessTime == null || currentTime - lastAccessTime > expirationTimeInMs) {
87+
apiAccessQueue.pollFirst(); // remove from head
88+
apiAccessMap.remove(oldestHash);
89+
} else {
90+
break; // is up-to-date
91+
}
92+
}
93+
}
94+
95+
private long computeApiHash(String route, String method, int statusCode) {
96+
long result = 17;
97+
result = 31 * result + route.hashCode();
98+
result = 31 * result + method.hashCode();
99+
result = 31 * result + statusCode;
100+
return result;
101+
}
102+
}
Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,38 @@
11
package com.datadog.appsec.api.security;
22

3+
import com.datadog.appsec.gateway.AppSecRequestContext;
34
import datadog.trace.api.Config;
4-
import java.util.concurrent.atomic.AtomicLong;
55

66
public class ApiSecurityRequestSampler {
77

8-
private volatile int sampling;
9-
private final AtomicLong cumulativeCounter = new AtomicLong();
8+
private final ApiAccessTracker apiAccessTracker;
9+
private final Config config;
1010

1111
public ApiSecurityRequestSampler(final Config config) {
12-
sampling = computeSamplingParameter(config.getApiSecurityRequestSampleRate());
12+
this.apiAccessTracker = new ApiAccessTracker();
13+
this.config = config;
1314
}
1415

15-
/**
16-
* Sets the new sampling parameter
17-
*
18-
* @return {@code true} if the value changed
19-
*/
20-
public boolean setSampling(final float newSamplingFloat) {
21-
int newSampling = computeSamplingParameter(newSamplingFloat);
22-
if (newSampling != sampling) {
23-
sampling = newSampling;
24-
cumulativeCounter.set(0); // Reset current sampling counter
25-
return true;
16+
public boolean sampleRequest(AppSecRequestContext ctx) {
17+
if (!config.isApiSecurityEnabled() || ctx == null) {
18+
return false;
2619
}
27-
return false;
28-
}
29-
30-
public int getSampling() {
31-
return sampling;
32-
}
3320

34-
public boolean sampleRequest() {
35-
long prevValue = cumulativeCounter.getAndAdd(sampling);
36-
long newValue = prevValue + sampling;
37-
if (newValue / 100 == prevValue / 100 + 1) {
38-
// Sample request
39-
return true;
21+
String route = ctx.getRoute();
22+
if (route == null) {
23+
return false;
4024
}
41-
// Skipped by sampling
42-
return false;
43-
}
4425

45-
static int computeSamplingParameter(final float pct) {
46-
if (pct >= 1) {
47-
return 100;
26+
String method = ctx.getMethod();
27+
if (method == null) {
28+
return false;
4829
}
49-
if (pct < 0) {
50-
// Api security can only be disabled by setting the sampling to zero, so we set it to 100%.
51-
// TODO: We probably want a warning here.
52-
return 100;
30+
31+
int statusCode = ctx.getResponseStatus();
32+
if (statusCode == 0) {
33+
return false;
5334
}
54-
return (int) (pct * 100);
35+
36+
return apiAccessTracker.updateApiAccessIfExpired(route, method, statusCode);
5537
}
5638
}

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ public class AppSecRequestContext implements DataBundle, Closeable {
8888
private String scheme;
8989
private String method;
9090
private String savedRawURI;
91+
private String route;
9192
private final Map<String, List<String>> requestHeaders = new LinkedHashMap<>();
9293
private final Map<String, List<String>> responseHeaders = new LinkedHashMap<>();
9394
private volatile Map<String, List<String>> collectedCookies;
@@ -270,15 +271,15 @@ void setScheme(String scheme) {
270271
this.scheme = scheme;
271272
}
272273

273-
String getMethod() {
274+
public String getMethod() {
274275
return method;
275276
}
276277

277278
void setMethod(String method) {
278279
this.method = method;
279280
}
280281

281-
String getSavedRawURI() {
282+
public String getSavedRawURI() {
282283
return savedRawURI;
283284
}
284285

@@ -290,6 +291,18 @@ void setRawURI(String savedRawURI) {
290291
this.savedRawURI = savedRawURI;
291292
}
292293

294+
public String getRoute() {
295+
return route;
296+
}
297+
298+
void setRoute(String route) {
299+
if (this.route != null && this.route.compareToIgnoreCase(route) != 0) {
300+
throw new IllegalStateException(
301+
"Forbidden attempt to set different route for given request context");
302+
}
303+
this.route = route;
304+
}
305+
293306
void addRequestHeader(String name, String value) {
294307
if (finishedRequestHeaders) {
295308
throw new IllegalStateException("Request headers were said to be finished before");

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -828,8 +828,10 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) {
828828
}
829829
}
830830

831-
if (!spanInfo.isRequiresPostProcessing()) {
832-
ctx.close();
831+
// Route used in post-processing
832+
Object route = spanInfo.getTags().get(Tags.HTTP_ROUTE);
833+
if (route instanceof String) {
834+
ctx.setRoute((String) route);
833835
}
834836
return NoopFlow.INSTANCE;
835837
}
@@ -889,6 +891,7 @@ private void onRequestHeader(RequestContext ctx_, String name, String value) {
889891
}
890892
}
891893

894+
// This handler is executed in a separate thread due the computation possible overhead
892895
private void onPostProcessing(RequestContext ctx_) {
893896
AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
894897
if (ctx == null) {
@@ -1068,7 +1071,7 @@ private Flow<Void> maybePublishResponseData(AppSecRequestContext ctx) {
10681071
private void maybeExtractSchemas(AppSecRequestContext ctx) {
10691072
boolean extractSchema = false;
10701073
if (Config.get().isApiSecurityEnabled() && requestSampler != null) {
1071-
extractSchema = requestSampler.sampleRequest();
1074+
extractSchema = requestSampler.sampleRequest(ctx);
10721075
}
10731076

10741077
if (!extractSchema) {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.datadog.appsec.api.security
2+
3+
import datadog.trace.test.util.DDSpecification
4+
5+
class ApiAccessTrackerTest extends DDSpecification {
6+
def "should add new api access and update if expired"() {
7+
given: "An ApiAccessTracker with capacity 2 and expiration time 1 second"
8+
def tracker = new ApiAccessTracker(2, 1000)
9+
10+
when: "Adding new api access"
11+
tracker.updateApiAccessIfExpired("route1", "GET", 200)
12+
def firstAccessTime = tracker.apiAccessMap.values().iterator().next()
13+
14+
then: "The access is added"
15+
tracker.apiAccessMap.size() == 1
16+
17+
when: "Waiting more than expiration time and adding another access with the same key"
18+
Thread.sleep(1100) // Waiting more than 1 second to ensure expiration
19+
tracker.updateApiAccessIfExpired("route1", "GET", 200)
20+
def secondAccessTime = tracker.apiAccessMap.values().iterator().next()
21+
22+
then: "The access is updated and moved to the end"
23+
tracker.apiAccessMap.size() == 1
24+
secondAccessTime > firstAccessTime
25+
}
26+
27+
def "should remove the oldest access when capacity is exceeded"() {
28+
given: "An ApiAccessTracker with capacity 1"
29+
def tracker = new ApiAccessTracker(1, 1000)
30+
31+
when: "Adding two api accesses"
32+
tracker.updateApiAccessIfExpired("route1", "GET", 200)
33+
Thread.sleep(100) // Delay to ensure different timestamps
34+
tracker.updateApiAccessIfExpired("route2", "POST", 404)
35+
36+
then: "The oldest access is removed"
37+
tracker.apiAccessMap.size() == 1
38+
!tracker.apiAccessMap.containsKey(tracker.computeApiHash("route1", "GET", 200))
39+
tracker.apiAccessMap.containsKey(tracker.computeApiHash("route2", "POST", 404))
40+
}
41+
42+
def "should not update access if not expired"() {
43+
given: "An ApiAccessTracker with a short expiration time"
44+
def tracker = new ApiAccessTracker(2, 2000) // 2 seconds expiration
45+
46+
when: "Adding an api access and updating it before it expires"
47+
tracker.updateApiAccessIfExpired("route1", "GET", 200)
48+
def updateTime = System.currentTimeMillis()
49+
boolean updatedBeforeExpiration = tracker.updateApiAccessIfExpired("route1", "GET", 200)
50+
51+
then: "The access is not updated"
52+
!updatedBeforeExpiration
53+
tracker.apiAccessMap.get(tracker.computeApiHash("route1", "GET", 200)) == updateTime
54+
}
55+
}
Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,35 @@
11
package com.datadog.appsec.api.security
22

3+
import com.datadog.appsec.gateway.AppSecRequestContext
34
import datadog.trace.api.Config
45
import datadog.trace.test.util.DDSpecification
5-
import spock.lang.Shared
66

77
class ApiSecurityRequestSamplerTest extends DDSpecification {
88

9-
@Shared
10-
static final float DEFAULT_SAMPLE_RATE = Config.get().getApiSecurityRequestSampleRate()
9+
def config = Mock(Config) {
10+
isApiSecurityEnabled() >> true
11+
}
1112

12-
void 'Api Security Request Sample Rate'() {
13-
given:
14-
def config = Spy(Config.get())
15-
config.getApiSecurityRequestSampleRate() >> sampleRate
16-
def sampler = new ApiSecurityRequestSampler(config)
13+
def sampler = new ApiSecurityRequestSampler(config)
1714

15+
void 'Api Security Sample Request'() {
1816
when:
19-
def numOfRequest = expectedSampledRequests.size()
20-
def results = new int[numOfRequest]
21-
for (int i = 0; i < numOfRequest; i++) {
22-
results[i] = sampler.sampleRequest() ? 1 : 0
17+
def span = Mock(AppSecRequestContext) {
18+
getSavedRawURI() >> route
19+
getMethod() >> method
20+
getResponseStatus() >> statusCode
2321
}
22+
def sample = sampler.sampleRequest(span)
2423

2524
then:
26-
results == expectedSampledRequests as int[]
25+
sample == sampleResult
2726

2827
where:
29-
sampleRate | expectedSampledRequests
30-
DEFAULT_SAMPLE_RATE | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] // Default sample rate - 10%
31-
0.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
32-
0.1 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
33-
0.25 | [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]
34-
0.33 | [0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1]
35-
0.5 | [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
36-
0.75 | [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
37-
0.9 | [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
38-
0.99 | [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
39-
1.0 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
40-
1.25 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] // Wrong sample rate - use 100%
41-
-0.5 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] // Wrong sample rate - use 100%
42-
}
43-
44-
void 'update sample rate'() {
45-
given:
46-
def config = Spy(Config.get())
47-
def sampler = new ApiSecurityRequestSampler(config)
48-
49-
when:
50-
sampler.setSampling(0.2)
51-
52-
then:
53-
sampler.sampling == 20
28+
method | route | statusCode | sampleResult
29+
'GET' | 'route1' | 200 | true
30+
'GET' | 'route2' | null | false
31+
'GET' | null | 404 | false
32+
'TOP' | 999 | 404 | true
33+
null | '999' | 404 | false
5434
}
55-
}
35+
}

0 commit comments

Comments
 (0)