|
1 | 1 | package com.datadog.appsec.api.security; |
2 | 2 |
|
3 | 3 | import com.datadog.appsec.gateway.AppSecRequestContext; |
4 | | -import datadog.trace.api.Config; |
| 4 | +import datadog.trace.bootstrap.instrumentation.api.Tags; |
| 5 | +import datadog.trace.util.NonBlockingSemaphore; |
| 6 | + |
| 7 | +import java.util.Deque; |
| 8 | +import java.util.Map; |
| 9 | +import java.util.concurrent.ConcurrentHashMap; |
| 10 | +import java.util.concurrent.ConcurrentLinkedDeque; |
5 | 11 |
|
6 | 12 | public class ApiSecurityRequestSampler { |
7 | 13 |
|
8 | | - private final ApiAccessTracker apiAccessTracker; |
9 | | - private final Config config; |
| 14 | + private static final int MAX_POST_PROCESSING_TASKS = 8; |
| 15 | + private static final int INTERVAL_SECONDS = 30; |
| 16 | + private static final int MAX_SIZE = 4096; |
| 17 | + private final Map<Long, Long> apiAccessMap; // Map<hash, timestamp> |
| 18 | + private final Deque<Long> apiAccessQueue; // hashes ordered by access time |
| 19 | + private final long expirationTimeInMs; |
| 20 | + private final int capacity; |
| 21 | + |
| 22 | + private final NonBlockingSemaphore counter = NonBlockingSemaphore.withPermitCount(MAX_POST_PROCESSING_TASKS); |
10 | 23 |
|
11 | | - public ApiSecurityRequestSampler(final Config config) { |
12 | | - this.apiAccessTracker = new ApiAccessTracker(); |
13 | | - this.config = config; |
| 24 | + public ApiSecurityRequestSampler() { |
| 25 | + this(MAX_SIZE, INTERVAL_SECONDS * 1000); |
14 | 26 | } |
15 | 27 |
|
16 | | - public boolean sampleRequest(AppSecRequestContext ctx) { |
| 28 | + public ApiSecurityRequestSampler(int capacity, long expirationTimeInMs) { |
| 29 | + this.capacity = capacity; |
| 30 | + this.expirationTimeInMs = expirationTimeInMs; |
| 31 | + this.apiAccessMap = new ConcurrentHashMap<>(); |
| 32 | + this.apiAccessQueue = new ConcurrentLinkedDeque<>(); |
| 33 | + } |
| 34 | + |
| 35 | + public void preSampleRequest(final AppSecRequestContext ctx, final Map<String, Object> tags) { |
| 36 | + final Object route = tags.get(Tags.HTTP_ROUTE); |
| 37 | + if (route instanceof String) { |
| 38 | + ctx.setRoute((String) route); |
| 39 | + } |
| 40 | + |
17 | 41 | if (!isValid(ctx)) { |
18 | | - return false; |
| 42 | + return; |
19 | 43 | } |
20 | 44 |
|
21 | | - return apiAccessTracker.updateApiAccessIfExpired( |
22 | | - ctx.getRoute(), ctx.getMethod(), ctx.getResponseStatus()); |
| 45 | + if (!isApiAccessExpired(ctx.getRoute(), ctx.getMethod(), ctx.getResponseStatus())) { |
| 46 | + return; |
| 47 | + } |
| 48 | + |
| 49 | + if (counter.acquire()) { |
| 50 | + ctx.setKeepOpenForApiSecurityPostProcessing(true); |
| 51 | + } |
23 | 52 | } |
24 | 53 |
|
25 | | - public boolean preSampleRequest(AppSecRequestContext ctx) { |
| 54 | + public boolean sampleRequest(AppSecRequestContext ctx) { |
26 | 55 | if (!isValid(ctx)) { |
27 | 56 | return false; |
28 | 57 | } |
29 | 58 |
|
30 | | - return apiAccessTracker.isApiAccessExpired( |
| 59 | + return updateApiAccessIfExpired( |
31 | 60 | ctx.getRoute(), ctx.getMethod(), ctx.getResponseStatus()); |
32 | 61 | } |
33 | 62 |
|
34 | 63 | private boolean isValid(AppSecRequestContext ctx) { |
35 | | - return config.isApiSecurityEnabled() |
36 | | - && ctx != null |
| 64 | + return ctx != null |
37 | 65 | && ctx.getRoute() != null |
38 | 66 | && ctx.getMethod() != null |
39 | 67 | && ctx.getResponseStatus() != 0; |
40 | 68 | } |
| 69 | + |
| 70 | + /** |
| 71 | + * Updates the API access log with the given route, method, and status code. If the record already |
| 72 | + * exists and is outdated, it is updated by moving to the end of the list. If the record does not |
| 73 | + * exist, a new record is added. If the capacity limit is reached, the oldest record is removed. |
| 74 | + * This method should not be called concurrently by multiple threads, due absence of additional |
| 75 | + * synchronization for updating data structures is not required. |
| 76 | + * |
| 77 | + * @param route The route of the API endpoint request |
| 78 | + * @param method The method of the API request |
| 79 | + * @param statusCode The HTTP response status code of the API request |
| 80 | + * @return return true if the record was updated or added, false otherwise |
| 81 | + */ |
| 82 | + public boolean updateApiAccessIfExpired(String route, String method, int statusCode) { |
| 83 | + long currentTime = System.currentTimeMillis(); |
| 84 | + long hash = computeApiHash(route, method, statusCode); |
| 85 | + |
| 86 | + // New or updated record |
| 87 | + boolean isNewOrUpdated = false; |
| 88 | + if (!apiAccessMap.containsKey(hash) |
| 89 | + || currentTime - apiAccessMap.get(hash) > expirationTimeInMs) { |
| 90 | + |
| 91 | + cleanupExpiredEntries(currentTime); |
| 92 | + |
| 93 | + apiAccessMap.put(hash, currentTime); // Update timestamp |
| 94 | + // move hash to the end of the queue |
| 95 | + apiAccessQueue.remove(hash); |
| 96 | + apiAccessQueue.addLast(hash); |
| 97 | + isNewOrUpdated = true; |
| 98 | + |
| 99 | + // Remove the oldest hash if capacity is reached |
| 100 | + while (apiAccessMap.size() > this.capacity) { |
| 101 | + Long oldestHash = apiAccessQueue.pollFirst(); |
| 102 | + if (oldestHash != null) { |
| 103 | + apiAccessMap.remove(oldestHash); |
| 104 | + } |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + return isNewOrUpdated; |
| 109 | + } |
| 110 | + |
| 111 | + public boolean isApiAccessExpired(String route, String method, int statusCode) { |
| 112 | + long currentTime = System.currentTimeMillis(); |
| 113 | + long hash = computeApiHash(route, method, statusCode); |
| 114 | + return !apiAccessMap.containsKey(hash) |
| 115 | + || currentTime - apiAccessMap.get(hash) > expirationTimeInMs; |
| 116 | + } |
| 117 | + |
| 118 | + private void cleanupExpiredEntries(long currentTime) { |
| 119 | + while (!apiAccessQueue.isEmpty()) { |
| 120 | + Long oldestHash = apiAccessQueue.peekFirst(); |
| 121 | + if (oldestHash == null) break; |
| 122 | + |
| 123 | + Long lastAccessTime = apiAccessMap.get(oldestHash); |
| 124 | + if (lastAccessTime == null || currentTime - lastAccessTime > expirationTimeInMs) { |
| 125 | + apiAccessQueue.pollFirst(); // remove from head |
| 126 | + apiAccessMap.remove(oldestHash); |
| 127 | + } else { |
| 128 | + break; // is up-to-date |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + private long computeApiHash(String route, String method, int statusCode) { |
| 134 | + long result = 17; |
| 135 | + result = 31 * result + route.hashCode(); |
| 136 | + result = 31 * result + method.hashCode(); |
| 137 | + result = 31 * result + statusCode; |
| 138 | + return result; |
| 139 | + } |
| 140 | + |
41 | 141 | } |
0 commit comments