|
| 1 | +/* |
| 2 | + * dCache - http://www.dcache.org/ |
| 3 | + * |
| 4 | + * Copyright (C) 2025 Deutsches Elektronen-Synchrotron |
| 5 | + * |
| 6 | + * This program is free software: you can redistribute it and/or modify |
| 7 | + * it under the terms of the GNU Affero General Public License as |
| 8 | + * published by the Free Software Foundation, either version 3 of the |
| 9 | + * License, or (at your option) any later version. |
| 10 | + * |
| 11 | + * This program is distributed in the hope that it will be useful, |
| 12 | + * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 13 | + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 14 | + * GNU Affero General Public License for more details. |
| 15 | + * |
| 16 | + * You should have received a copy of the GNU Affero General Public License |
| 17 | + * along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 18 | + */ |
| 19 | +package org.dcache.util.jetty; |
| 20 | + |
| 21 | +import com.google.common.annotations.VisibleForTesting; |
| 22 | +import com.google.common.cache.Cache; |
| 23 | +import com.google.common.cache.CacheBuilder; |
| 24 | +import com.google.common.util.concurrent.RateLimiter; |
| 25 | +import dmg.cells.nucleus.CellCommandListener; |
| 26 | +import dmg.util.command.Command; |
| 27 | +import dmg.util.command.Option; |
| 28 | +import org.eclipse.jetty.http.HttpStatus; |
| 29 | +import org.eclipse.jetty.server.Handler; |
| 30 | +import org.eclipse.jetty.server.Request; |
| 31 | +import org.eclipse.jetty.server.handler.HandlerCollection; |
| 32 | +import org.slf4j.Logger; |
| 33 | +import org.slf4j.LoggerFactory; |
| 34 | + |
| 35 | +import javax.servlet.ServletException; |
| 36 | +import javax.servlet.http.HttpServletRequest; |
| 37 | +import javax.servlet.http.HttpServletResponse; |
| 38 | +import java.io.IOException; |
| 39 | +import java.time.Duration; |
| 40 | +import java.time.temporal.ChronoUnit; |
| 41 | +import java.util.concurrent.Callable; |
| 42 | +import java.util.concurrent.atomic.AtomicInteger; |
| 43 | + |
| 44 | +import static com.google.common.base.Preconditions.checkArgument; |
| 45 | + |
| 46 | +/** |
| 47 | + * A Jetty handler collection that enforces per-client and global rate limiting, |
| 48 | + * as well as temporary blocking of clients that exceed error thresholds. |
| 49 | + * <p> |
| 50 | + * Each client IP is assigned a rate limiter and error counter. If a client exceeds |
| 51 | + * the allowed number of errors within a short time window, it is blocked for a fixed duration. |
| 52 | + * Requests are rejected with HTTP 429 if rate limits are exceeded or the client is blocked. |
| 53 | + * <p> |
| 54 | + * The handler uses Guava caches to manage rate limiters, error counters, and blocked clients, |
| 55 | + * automatically expiring idle entries. |
| 56 | + * |
| 57 | + * Based on original code by Sandro Grizzo. |
| 58 | + */ |
| 59 | +public class RateLimitedHandlerList extends HandlerCollection implements CellCommandListener { |
| 60 | + |
| 61 | + private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitedHandlerList.class); |
| 62 | + |
| 63 | + /** |
| 64 | + * Initial capacity of the client IP rates limiters map size. |
| 65 | + */ |
| 66 | + private final int CLIENT_IP_CACHE_INITIAL_CAPACITY = 1024; |
| 67 | + |
| 68 | + /** |
| 69 | + * Maximum number of errors allowed per client before blocking. |
| 70 | + */ |
| 71 | + private int maxErrorsPerClient; |
| 72 | + |
| 73 | + /** |
| 74 | + * An object used as a value when client is blocked. |
| 75 | + */ |
| 76 | + private final Object BLOCK = new Object(); |
| 77 | + |
| 78 | + /** |
| 79 | + * Calculated per-client rate limit based on the global rate limit and factor. |
| 80 | + */ |
| 81 | + private double perClientRate; |
| 82 | + |
| 83 | + /** |
| 84 | + * Rate limiter for all requests. |
| 85 | + */ |
| 86 | + private final RateLimiter globalRateLimiter; |
| 87 | + |
| 88 | + /** |
| 89 | + * Cache mapping client identifiers (e.g., IP addresses) to their respective rate limiters. |
| 90 | + */ |
| 91 | + private final Cache<String, RateLimiter> perClientRates ; |
| 92 | + |
| 93 | + /** |
| 94 | + * Cache mapping client identifiers to a blocking marker object for temporarily blocked clients. |
| 95 | + */ |
| 96 | + private final Cache<String, Object> blockedClients; |
| 97 | + |
| 98 | + /** |
| 99 | + * Cache mapping client identifiers to their respective error counters. |
| 100 | + */ |
| 101 | + private final Cache<String, AtomicInteger> perClientErrorCount; |
| 102 | + |
| 103 | + public static class Configuration { |
| 104 | + private int maxClientsToTrack; |
| 105 | + private long maxGlobalRequestsPerSecond; |
| 106 | + private int maxErrorsPerClient; |
| 107 | + private int perClientPercent; |
| 108 | + private long clientIdleTime; |
| 109 | + private ChronoUnit clientIdleTimeUnit; |
| 110 | + private long clientBlockingTime; |
| 111 | + private ChronoUnit clientBlockingTimeUnit; |
| 112 | + private long errorAcceptanceWindow; |
| 113 | + private ChronoUnit errorAcceptanceWindowUnit; |
| 114 | + |
| 115 | + public void setGlobalRequestsPerSecond(long value) { |
| 116 | + this.maxGlobalRequestsPerSecond = value; |
| 117 | + |
| 118 | + } |
| 119 | + |
| 120 | + public void setNumErrorsBeforeBlocking(int value) { |
| 121 | + this.maxErrorsPerClient = value; |
| 122 | + } |
| 123 | + |
| 124 | + public void setLimitPercentagePerClient(int value) { |
| 125 | + this.perClientPercent = value; |
| 126 | + } |
| 127 | + |
| 128 | + public void setClientIdleTime(long value) { |
| 129 | + this.clientIdleTime = value; |
| 130 | + } |
| 131 | + |
| 132 | + public void setClientBlockingTime(long clientBlockingTime) { |
| 133 | + this.clientBlockingTime = clientBlockingTime; |
| 134 | + } |
| 135 | + |
| 136 | + public void setErrorCountingWindow(long errorAcceptanceWindow) { |
| 137 | + this.errorAcceptanceWindow = errorAcceptanceWindow; |
| 138 | + } |
| 139 | + |
| 140 | + public void setClientIdleTimeUnit(ChronoUnit clientIdleTimeUnit) { |
| 141 | + this.clientIdleTimeUnit = clientIdleTimeUnit; |
| 142 | + } |
| 143 | + |
| 144 | + public void setClientBlockingTimeUnit(ChronoUnit clientBlockingTimeUnit) { |
| 145 | + this.clientBlockingTimeUnit = clientBlockingTimeUnit; |
| 146 | + } |
| 147 | + |
| 148 | + public void setErrorCountingWindowUnit(ChronoUnit errorAcceptanceWindowUnit) { |
| 149 | + this.errorAcceptanceWindowUnit = errorAcceptanceWindowUnit; |
| 150 | + } |
| 151 | + |
| 152 | + public void setMaxClientsToTrack(int maxClientsToTrack) { |
| 153 | + this.maxClientsToTrack = maxClientsToTrack; |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + |
| 158 | + /** * Constructs a RateLimitedHandlerList with parameters from the given configuration. |
| 159 | + * |
| 160 | + * @param configuration the configuration object containing rate limiting and blocking parameters |
| 161 | + */ |
| 162 | + public RateLimitedHandlerList(Configuration configuration) { |
| 163 | + this(configuration.maxClientsToTrack, |
| 164 | + configuration.maxGlobalRequestsPerSecond, |
| 165 | + configuration.maxErrorsPerClient, |
| 166 | + configuration.perClientPercent, |
| 167 | + Duration.of(configuration.clientIdleTime, configuration.clientIdleTimeUnit), |
| 168 | + Duration.of(configuration.clientBlockingTime, configuration.clientBlockingTimeUnit), |
| 169 | + Duration.of(configuration.errorAcceptanceWindow, configuration.errorAcceptanceWindowUnit)); |
| 170 | + } |
| 171 | + |
| 172 | + /** * Constructs a RateLimitedHandlerList with specified rate limiting and blocking parameters. |
| 173 | + * |
| 174 | + * @param maxGlobalRequestsPerSecond maximum number of requests per second allowed globally |
| 175 | + * @param maxErrorsPerClient maximum number of errors allowed per client before blocking |
| 176 | + * @param perClientPercent percentage of the global rate limit to apply per client (1 < percent <= 100) |
| 177 | + * @param clientIdleTime duration after which an idle client's rate limiter is removed |
| 178 | + * @param clientBlockingTime duration for which a client is blocked after exceeding error threshold |
| 179 | + * @param errorAcceptanceWindow time window for counting errors per client |
| 180 | + */ |
| 181 | + public RateLimitedHandlerList( |
| 182 | + int maxClientsToTrack, |
| 183 | + long maxGlobalRequestsPerSecond, |
| 184 | + int maxErrorsPerClient, |
| 185 | + int perClientPercent, |
| 186 | + Duration clientIdleTime, |
| 187 | + Duration clientBlockingTime, |
| 188 | + Duration errorAcceptanceWindow) { |
| 189 | + |
| 190 | + perClientRates = CacheBuilder.newBuilder() |
| 191 | + .initialCapacity(CLIENT_IP_CACHE_INITIAL_CAPACITY) |
| 192 | + .maximumSize(maxClientsToTrack) |
| 193 | + .expireAfterAccess(clientIdleTime) |
| 194 | + .build(); |
| 195 | + |
| 196 | + blockedClients = CacheBuilder.newBuilder() |
| 197 | + .maximumSize(maxClientsToTrack) |
| 198 | + .expireAfterWrite(clientBlockingTime) |
| 199 | + .build(); |
| 200 | + |
| 201 | + perClientErrorCount = CacheBuilder.newBuilder() |
| 202 | + .initialCapacity(CLIENT_IP_CACHE_INITIAL_CAPACITY) |
| 203 | + .maximumSize(maxClientsToTrack) |
| 204 | + .expireAfterAccess(errorAcceptanceWindow) |
| 205 | + .build(); |
| 206 | + |
| 207 | + globalRateLimiter = RateLimiter.create(maxGlobalRequestsPerSecond); |
| 208 | + perClientRate = perClientPercent * maxGlobalRequestsPerSecond / 100.0d; |
| 209 | + this.maxErrorsPerClient = maxErrorsPerClient; |
| 210 | + } |
| 211 | + |
| 212 | + |
| 213 | + @Override |
| 214 | + public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { |
| 215 | + |
| 216 | + String client = getClientIp(request); |
| 217 | + |
| 218 | + boolean blocked = blockedClients.getIfPresent(client) != null; |
| 219 | + if (blocked) { |
| 220 | + LOGGER.warn("Blocking client with too many auth errors {}", client); |
| 221 | + response.setStatus(HttpStatus.TOO_MANY_REQUESTS_429); |
| 222 | + response.getWriter().write("Server is busy. Please try again later."); |
| 223 | + baseRequest.setHandled(true); |
| 224 | + return; |
| 225 | + } |
| 226 | + |
| 227 | + if (!getClientRateLimiter(client).tryAcquire()) { |
| 228 | + LOGGER.warn("Blocking client with too many requests {}", client); |
| 229 | + response.setStatus(HttpStatus.TOO_MANY_REQUESTS_429); |
| 230 | + response.getWriter().write("Server is busy. Please try again later."); |
| 231 | + baseRequest.setHandled(true); |
| 232 | + return; |
| 233 | + } |
| 234 | + |
| 235 | + if (!globalRateLimiter.tryAcquire()) { |
| 236 | + LOGGER.warn("Blocking client due to globally too many requests {}", client); |
| 237 | + response.setStatus(HttpStatus.TOO_MANY_REQUESTS_429); |
| 238 | + response.getWriter().write("Server is busy. Please try again later."); |
| 239 | + baseRequest.setHandled(true); |
| 240 | + return; |
| 241 | + } |
| 242 | + |
| 243 | + Handler[] handlers = this.getHandlers(); |
| 244 | + if (handlers != null && this.isStarted()) { |
| 245 | + for (Handler handler : handlers) { |
| 246 | + handler.handle(target, baseRequest, request, response); |
| 247 | + if (baseRequest.isHandled()) { |
| 248 | + // block clients that hammer with authentication failures |
| 249 | + if (response.getStatus() >= 400 && response.getStatus() <= 407) { |
| 250 | + int errors = getClientErrorRateLimiter(client).incrementAndGet(); |
| 251 | + if (errors >= maxErrorsPerClient) { |
| 252 | + blockedClients.put(client, BLOCK); |
| 253 | + // as client blocked, no reason to keep track of further errors |
| 254 | + perClientErrorCount.invalidate(client); |
| 255 | + perClientRates.invalidate(client); |
| 256 | + } |
| 257 | + } |
| 258 | + return; |
| 259 | + } |
| 260 | + } |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + /** |
| 265 | + * Extracts the client IP address from the request, considering possible proxies. |
| 266 | + * |
| 267 | + * @param request the HTTP request |
| 268 | + * @return the client IP address |
| 269 | + */ |
| 270 | + private String getClientIp(HttpServletRequest request) { |
| 271 | + String forwardedIp = request.getHeader("X-Forwarded-For"); |
| 272 | + if (forwardedIp == null) { |
| 273 | + return request.getRemoteAddr(); |
| 274 | + } |
| 275 | + return forwardedIp.split(",")[0]; |
| 276 | + } |
| 277 | + |
| 278 | + /** |
| 279 | + * Retrieves or creates a RateLimiter for the specified client. |
| 280 | + * |
| 281 | + * @param client the client identifier (e.g., IP address) |
| 282 | + * @return the RateLimiter for the client |
| 283 | + */ |
| 284 | + private RateLimiter getClientRateLimiter(String client) { |
| 285 | + try { |
| 286 | + return perClientRates.get(client, () -> RateLimiter.create(perClientRate)); |
| 287 | + } catch (Exception e) { |
| 288 | + // should not happen |
| 289 | + throw new RuntimeException("Failed to get or create rate limiter for client " + client, e); |
| 290 | + } |
| 291 | + } |
| 292 | + |
| 293 | + /** |
| 294 | + * Retrieves or creates an AtomicInteger to count errors for the specified client. |
| 295 | + * |
| 296 | + * @param client the client identifier (e.g., IP address) |
| 297 | + * @return the AtomicInteger counting errors for the client |
| 298 | + */ |
| 299 | + private AtomicInteger getClientErrorRateLimiter(String client) { |
| 300 | + try { |
| 301 | + return perClientErrorCount.get(client, () -> new AtomicInteger(0)); |
| 302 | + } catch (Exception e) { |
| 303 | + // should not happen |
| 304 | + throw new RuntimeException("Failed to get or create error counter per client client " + client, e); |
| 305 | + } |
| 306 | + } |
| 307 | + |
| 308 | + @Override |
| 309 | + public String toString() { |
| 310 | + return String.format("RateLimitedHandlerList{globalRate=%.1f, perClientRate=%.1f, maxErrorsPerClient=%d}", |
| 311 | + globalRateLimiter.getRate(), perClientRate, maxErrorsPerClient); |
| 312 | + } |
| 313 | + |
| 314 | + @VisibleForTesting |
| 315 | + void setMaxGlobalRequestsPerSecond(int maxRequestsPerSecond) { |
| 316 | + checkArgument(maxRequestsPerSecond > 0, "maxRequestsPerSecond must be positive"); |
| 317 | + globalRateLimiter.setRate(maxRequestsPerSecond); |
| 318 | + } |
| 319 | + |
| 320 | + @Command(name="limits reset", description="Reset all rate limiters and error counters") |
| 321 | + public class LimitsResetCommand implements Callable<String> { |
| 322 | + @Override |
| 323 | + public String call() { |
| 324 | + perClientRates.invalidateAll(); |
| 325 | + blockedClients.invalidateAll(); |
| 326 | + perClientErrorCount.invalidateAll(); |
| 327 | + |
| 328 | + return ""; |
| 329 | + } |
| 330 | + } |
| 331 | + |
| 332 | + |
| 333 | + @Command(name="limits info", description="Show current rate limits and statistics. The retuned information is approximate.") |
| 334 | + public class LimitsShowCommand implements Callable<String> { |
| 335 | + |
| 336 | + @Option(name="l", usage="Verbose listing") |
| 337 | + boolean verbose = false; |
| 338 | + |
| 339 | + @Override |
| 340 | + public String call() { |
| 341 | + StringBuilder sb = new StringBuilder(); |
| 342 | + sb.append(String.format("Global rate: %.1f requests/second\n", globalRateLimiter.getRate())); |
| 343 | + sb.append(String.format("Per-client rate: %.1f requests/second\n", perClientRate)); |
| 344 | + sb.append(String.format("Max errors per client before blocking: %d\n", maxErrorsPerClient)); |
| 345 | + sb.append(String.format("Currently blocked clients: ~%d\n", blockedClients.size())); |
| 346 | + if (verbose) { |
| 347 | + sb.append(" Blocked clients:\n"); |
| 348 | + blockedClients.asMap().keySet().forEach(client -> sb.append(" ").append(client).append("\n")); |
| 349 | + } |
| 350 | + sb.append(String.format("Tracked clients with rate limiters: ~%d\n", perClientRates.size())); |
| 351 | + if (verbose) { |
| 352 | + sb.append(" Clients with rate limiters:\n"); |
| 353 | + perClientRates.asMap().keySet().forEach(client -> sb.append(" ").append(client).append("\n")); |
| 354 | + } |
| 355 | + sb.append(String.format("Tracked clients with error counters: ~%d\n", perClientErrorCount.size())); |
| 356 | + if (verbose) { |
| 357 | + sb.append(" Clients with error counters:\n"); |
| 358 | + perClientErrorCount.asMap().forEach((client, counter) -> |
| 359 | + sb.append(" ").append(client).append(": ").append(counter.get()).append(" errors\n")); |
| 360 | + } |
| 361 | + return sb.toString(); |
| 362 | + } |
| 363 | + } |
| 364 | +} |
0 commit comments