diff --git a/api/ee/src/core/entitlements/types.py b/api/ee/src/core/entitlements/types.py index d64e90740d..2b3f00e3a3 100644 --- a/api/ee/src/core/entitlements/types.py +++ b/api/ee/src/core/entitlements/types.py @@ -9,6 +9,7 @@ class Tracker(str, Enum): FLAGS = "flags" COUNTERS = "counters" GAUGES = "gauges" + THROTTLES = "throttles" class Flag(str, Enum): @@ -47,6 +48,73 @@ class Probe(BaseModel): delta: Optional[bool] = False +class Bucket(BaseModel): + capacity: Optional[int] = None # max tokens in the bucket + rate: Optional[int] = None # tokens added per minute + algorithm: Optional[str] = None + + +class Category(str, Enum): + STANDARD = "standard" + CORE_FAST = "core_fast" + CORE_SLOW = "core_slow" + TRACING_FAST = "tracing_fast" + TRACING_SLOW = "tracing_slow" + SERVICES_FAST = "services_fast" + SERVICES_SLOW = "services_slow" + + +class Method(str, Enum): + POST = "post" + GET = "get" + PUT = "put" + PATCH = "patch" + DELETE = "delete" + QUERY = "query" + MUTATION = "mutation" + ANY = "any" + + +class Mode(str, Enum): + INCLUDE = "include" + EXCLUDE = "exclude" + + +class Throttle(BaseModel): + bucket: Bucket + mode: Mode + categories: list[Category] | None = None + endpoints: list[tuple[Method, str]] | None = None + + +ENDPOINTS = { + Category.CORE_FAST: [ + (Method.POST, "*/retrieve"), + ], + Category.CORE_SLOW: [ + # None defined yet + ], + Category.TRACING_FAST: [ + (Method.POST, "/otlp/v1/traces"), + ], + Category.TRACING_SLOW: [ + (Method.POST, "/tracing/*/query"), + # + (Method.POST, "/tracing/spans/analytics"), # LEGACY + ], + Category.SERVICES_FAST: [ + (Method.ANY, "/permissions/verify"), + ], + Category.SERVICES_SLOW: [ + # None defined yet + ], + Category.STANDARD: [ + # None defined yet + # CATCH ALL + ], +} + + CATALOG = [ { "title": "Hobby", @@ -216,6 +284,42 @@ class Probe(BaseModel): Gauge.USERS: Quota(limit=2, strict=True, free=2), Gauge.APPLICATIONS: Quota(strict=True), }, + Tracker.THROTTLES: [ + Throttle( + categories=[ + Category.STANDARD, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=120, + rate=120, + ), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=1200, + rate=1200, + ), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=120, + rate=1, + ), + ), + ], }, Plan.CLOUD_V0_PRO: { Tracker.FLAGS: { @@ -231,6 +335,42 @@ class Probe(BaseModel): Gauge.USERS: Quota(limit=10, strict=True, free=3), Gauge.APPLICATIONS: Quota(strict=True), }, + Tracker.THROTTLES: [ + Throttle( + categories=[ + Category.STANDARD, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=360, + rate=360, + ), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=3600, + rate=3600, + ), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=180, + rate=1, + ), + ), + ], }, Plan.CLOUD_V0_BUSINESS: { Tracker.FLAGS: { @@ -246,6 +386,42 @@ class Probe(BaseModel): Gauge.USERS: Quota(strict=True), Gauge.APPLICATIONS: Quota(strict=True), }, + Tracker.THROTTLES: [ + Throttle( + categories=[ + Category.STANDARD, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=3600, + rate=3600, + ), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + 
Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=36000, + rate=36000, + ), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + bucket=Bucket( + capacity=1800, + rate=1, + ), + ), + ], }, Plan.CLOUD_V0_HUMANITY_LABS: { Tracker.FLAGS: { diff --git a/api/ee/src/core/subscriptions/service.py b/api/ee/src/core/subscriptions/service.py index 74c331b08b..d614ddaf2b 100644 --- a/api/ee/src/core/subscriptions/service.py +++ b/api/ee/src/core/subscriptions/service.py @@ -20,6 +20,7 @@ ) from ee.src.core.subscriptions.interfaces import SubscriptionsDAOInterface from ee.src.core.entitlements.service import EntitlementsService +from oss.src.utils.caching import invalidate_cache from ee.src.core.meters.service import MetersService log = get_module_logger(__name__) @@ -71,7 +72,13 @@ async def update( *, subscription: SubscriptionDTO, ) -> Optional[SubscriptionDTO]: - return await self.subscriptions_dao.update(subscription=subscription) + updated = await self.subscriptions_dao.update(subscription=subscription) + if updated: + await invalidate_cache( + namespace="entitlements:subscription", + key={"organization_id": str(updated.organization_id)}, + ) + return updated async def start_reverse_trial( self, diff --git a/api/ee/src/services/throttling_service.py b/api/ee/src/services/throttling_service.py new file mode 100644 index 0000000000..a3abd8237c --- /dev/null +++ b/api/ee/src/services/throttling_service.py @@ -0,0 +1,293 @@ +from typing import Optional +from uuid import UUID +from fnmatch import fnmatchcase + +from fastapi import Request +from fastapi.responses import JSONResponse + +from oss.src.utils.caching import get_cache, set_cache +from oss.src.utils.logging import get_module_logger +from oss.src.utils.throttling import Algorithm, check_throttles + +from ee.src.core.entitlements.types import ( + ENTITLEMENTS, + ENDPOINTS, + Category, + Method, + Mode, + Throttle, + Tracker, +) +from ee.src.core.meters.service import MetersService +from ee.src.core.subscriptions.service import SubscriptionsService +from ee.src.core.subscriptions.types import Plan +from ee.src.dbs.postgres.meters.dao import MetersDAO +from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO + +log = get_module_logger(__name__) + +meters_service = MetersService( + meters_dao=MetersDAO(), +) + +subscriptions_service = SubscriptionsService( + subscriptions_dao=SubscriptionsDAO(), + meters_service=meters_service, +) + + +def _normalize_path(request: Request) -> str: + path = request.url.path + root_path = request.scope.get("root_path") + if root_path and path.startswith(root_path): + path = path[len(root_path) :] or "/" + return path + + +def _matches_endpoint( + method: str, + path: str, + endpoint_method: Method, + endpoint_path: str, +) -> bool: + if endpoint_method != Method.ANY and endpoint_method.value != method: + return False + + if "*" in endpoint_path: + return fnmatchcase(path, endpoint_path) + + return path == endpoint_path + + +def _resolve_categories( + method: str, + path: str, +) -> set[Category]: + categories: set[Category] = set() + + for category, endpoints in ENDPOINTS.items(): + for endpoint_method, endpoint_path in endpoints: + if _matches_endpoint(method, path, endpoint_method, endpoint_path): + categories.add(category) + break + + if not categories: + categories.add(Category.STANDARD) + + return categories + + +def _throttle_matches( + throttle: Throttle, + categories: set[Category], + 
method: str, + path: str, +) -> bool: + category_match = False + endpoint_match = False + + if throttle.categories: + category_match = any(category in categories for category in throttle.categories) + + if throttle.endpoints: + endpoint_match = any( + _matches_endpoint(method, path, endpoint_method, endpoint_path) + for endpoint_method, endpoint_path in throttle.endpoints + ) + + if throttle.categories is None and throttle.endpoints is None: + match = True + else: + match = category_match or endpoint_match + + if throttle.mode == Mode.INCLUDE: + return match + + if throttle.mode == Mode.EXCLUDE: + return not match + + return False + + +def _throttle_suffix( + throttle: Throttle, + matched_categories: Optional[set[Category]] = None, +) -> str: + if throttle.categories: + categories_source = ( + matched_categories if matched_categories else set(throttle.categories) + ) + categories = ",".join(sorted(category.value for category in categories_source)) + return f"cats:{categories}" + + if throttle.endpoints: + endpoints = ",".join( + sorted(f"{method.value}:{path}" for method, path in throttle.endpoints) + ) + return f"epts:{endpoints}" + + return "all" + + +async def _get_plan(organization_id: str) -> Optional[Plan]: + cache_key = { + "organization_id": organization_id, + } + + subscription_data = await get_cache( + namespace="entitlements:subscription", + key=cache_key, + ) + + if subscription_data is None: + subscription = await subscriptions_service.read( + organization_id=organization_id, + ) + + if not subscription: + return None + + subscription_data = { + "plan": subscription.plan.value, + } + + await set_cache( + namespace="entitlements:subscription", + key=cache_key, + value=subscription_data, + ) + + plan_value = subscription_data.get("plan") if subscription_data else None + if not plan_value: + return None + + try: + return Plan(plan_value) + + except ValueError: + log.warning("[throttle] Unknown plan", plan=plan_value) + + return None + + +async def throttling_middleware(request: Request, call_next): + if hasattr(request.state, "admin") and request.state.admin: + return await call_next(request) + + organization_id = ( + request.state.organization_id + if hasattr(request.state, "organization_id") + else None + ) + + if not organization_id: + return await call_next(request) + + plan = await _get_plan(str(organization_id)) + + if not plan or plan not in ENTITLEMENTS: + return await call_next(request) + + throttles: list[Throttle] = ENTITLEMENTS[plan].get(Tracker.THROTTLES) or [] + + if not throttles: + return await call_next(request) + + method = request.method.lower() + + path = _normalize_path(request) + + # log.debug( + # "[throttling] START", org=organization_id, plan=plan, method=method, path=path + # ) + + categories = _resolve_categories(method, path) + + checks: list[tuple[dict, int, int]] = [] + + for throttle in throttles: + if throttle.bucket.capacity is None or throttle.bucket.rate is None: + continue + + if not _throttle_matches(throttle, categories, method, path): + continue + + matched_categories = None + if throttle.categories: + matched_categories = categories.intersection(throttle.categories) + + key = { + "organization": str(organization_id), + "plan": plan.value, + "policy": _throttle_suffix(throttle, matched_categories=matched_categories), + } + + capacity = throttle.bucket.capacity + rate = throttle.bucket.rate + + if capacity <= 0 or rate <= 0: + continue + + checks.append((key, capacity, rate)) + + if not checks: + return await call_next(request) + + # Use GCRA 
by default (fast, smooth scheduling) unless explicitly configured + # All throttles in current entitlements use the same algorithm + algorithm = Algorithm.GCRA + if throttles and throttles[0].bucket.algorithm: + algo_str = throttles[0].bucket.algorithm.lower() + if algo_str == "tbra": + algorithm = Algorithm.TBRA + + # log.debug("[throttling] CHECK", org=organization_id, plan=plan, checks=checks) + + results = await check_throttles(checks, algorithm=algorithm) + + # Track minimum remaining tokens across all policies for the response header + min_remaining: int | None = None + + for idx, result in enumerate(results): + remaining = int(result.tokens_remaining or 0) + + if not result.allow: + key, capacity, rate = checks[idx] + + headers = { + "X-RateLimit-Limit": str(capacity), + "X-RateLimit-Remaining": str(remaining), + } + retry_after = ( + int(result.retry_after_seconds) + 1 + if result.retry_after_seconds > 0 + else None + ) + if retry_after: + headers["Retry-After"] = str(retry_after) + + detail = ( + f"Rate limit exceeded. Please retry after {retry_after} seconds." + if retry_after + else "Rate limit exceeded. Please try again later." + ) + + return JSONResponse( + status_code=429, + content={"detail": detail}, + headers=headers, + ) + + # Track minimum remaining across all allowed policies + if min_remaining is None or remaining < min_remaining: + min_remaining = remaining + + # log.debug("[throttling] ALLOW") + + response = await call_next(request) + + # Add rate limit header to successful responses + if min_remaining is not None: + response.headers["X-RateLimit-Remaining"] = str(min_remaining) + + return response diff --git a/api/entrypoints/routers.py b/api/entrypoints/routers.py index 9d12da4865..1e85435238 100644 --- a/api/entrypoints/routers.py +++ b/api/entrypoints/routers.py @@ -157,8 +157,13 @@ async def lifespan(*args, **kwargs): ) # MIDDLEWARE ------------------------------------------------------------------- -app.middleware("http")(authentication_middleware) +if is_ee(): + from ee.src.services.throttling_service import throttling_middleware + + app.middleware("http")(throttling_middleware) + +app.middleware("http")(authentication_middleware) app.middleware("http")(analytics_middleware) app.add_middleware( diff --git a/api/oss/src/utils/throttling.py b/api/oss/src/utils/throttling.py new file mode 100644 index 0000000000..ff76da5755 --- /dev/null +++ b/api/oss/src/utils/throttling.py @@ -0,0 +1,586 @@ +""" +API Rate Limiting (Throttling) via Redis. + +Three-layer architecture: + +Layer 1 (Scripts): Raw Redis Lua execution + Methods: + - execute_tbra(key, capacity, rate) + - execute_gcra(key, interval, tolerance) + Details: + - Computes current time internally + +Layer 2 (Library): Public API with precomputation + Methods: + - check_throttle(key, max_capacity, refill_rate, ...) + Details: + - Accepts key as str or dict + - Converts to algorithm-specific params + - Handles failures + +Layer 3 (User code): Middleware/decorators that resolve: + Methods: + - + Details: + - key: from endpoint, org_id, user_id, headers, etc + - params: from plan lookup, config, callbacks, etc + +Usage (simple - global limit): + result = await check_throttle("global", max_capacity=1000, refill_rate=100) + +Usage (with dict key): + result = await check_throttle({"org": org_id}, max_capacity=100, refill_rate=60) + +Usage (with multiple dimensions): + result = await check_throttle({"ep": endpoint, "org": org_id}, ...) 
+""" + +import time +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Callable, Awaitable, Any, Union + +from redis.asyncio import Redis + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.env import env + +log = get_module_logger(__name__) + + +# ============================================================================= +# Configuration +# ============================================================================= + +THROTTLE_DEBUG = False +THROTTLE_SOCKET_TIMEOUT = 0.1 + +# Time step: 1 second (can be 100ms for finer granularity) +TIME_STEP_MS = 1000 + +# Fixed-point scale for TBRA +_SCALE = 1000 + +# TTL: 60 minutes +_TTL_MS = 3600000 + +# Redis client +_redis: Optional[Redis] = None + + +def _get_redis() -> Redis: + global _redis + if _redis is None: + _redis = Redis.from_url( + url=env.redis.uri_volatile, + decode_responses=False, + socket_timeout=THROTTLE_SOCKET_TIMEOUT, + ) + return _redis + + +def _now_step() -> int: + """Current time as quantized step.""" + return int(time.time() * 1000) // TIME_STEP_MS + + +# ============================================================================= +# Layer 1: Lua Scripts (raw Redis execution) +# ============================================================================= + +_LUA_TBRA = """ +local key = KEYS[1] +local max_cap = tonumber(ARGV[1]) +local rate = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) + +local val = redis.call('GET', key) +local tokens, last + +if val then + local sep = string.find(val, '|') + tokens = tonumber(string.sub(val, 1, sep - 1)) + last = tonumber(string.sub(val, sep + 1)) +else + tokens = max_cap + last = now +end + +local elapsed = now - last +if elapsed > 0 then + tokens = tokens + elapsed * rate + if tokens > max_cap then tokens = max_cap end +end + +tokens = tokens - 1000 +local allow = tokens >= 0 and 1 or 0 +local retry = allow == 1 and 0 or math.ceil(-tokens / rate) + +redis.call('SET', key, tokens .. '|' .. 
now, 'PX', 3600000) + +return {allow, tokens, retry} +""" + +_LUA_GCRA = """ +local key = KEYS[1] +local interval = tonumber(ARGV[1]) +local tolerance = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) + +local tat = tonumber(redis.call('GET', key)) or now + +local limit = tat - tolerance +local allow, retry, new_tat, remaining + +if now < limit then + allow = 0 + retry = limit - now + new_tat = tat + remaining = 0 +else + allow = 1 + retry = 0 + new_tat = (tat > now and tat or now) + interval + -- Remaining burst capacity: how many more requests before hitting the limit + -- remaining = (tolerance - (new_tat - now)) / interval + local used = new_tat - now + if used < tolerance then + remaining = math.floor((tolerance - used) / interval) + else + remaining = 0 + end +end + +redis.call('SET', key, new_tat, 'PX', 3600000) + +return {allow, remaining, retry} +""" + +_sha_tbra: Optional[str] = None +_sha_gcra: Optional[str] = None + + +async def _ensure_scripts() -> tuple[str, str]: + global _sha_tbra, _sha_gcra + + r = _get_redis() + + if _sha_tbra is None or _sha_gcra is None: + _sha_tbra = await r.script_load(_LUA_TBRA) + _sha_gcra = await r.script_load(_LUA_GCRA) + + return str(_sha_tbra), str(_sha_gcra) + + +async def _exec_script(sha: str, key: str, *args) -> list: + global _sha_tbra, _sha_gcra + + r = _get_redis() + + try: + return await r.evalsha(sha, 1, key, *args) + + except Exception as e: + if "NOSCRIPT" in str(e): + _sha_tbra, _sha_gcra = None, None + + await _ensure_scripts() + + return await r.evalsha(sha, 1, key, *args) + + raise + + +async def execute_tbra( + key: str, + capacity: int, + rate: int, +) -> tuple[bool, float, int]: + """ + Layer 1: Execute TBRA script. + + Args: + key: Full Redis key + capacity: capacity * 1000 + rate: tokens per step * 1000 + + Returns: + (allow, tokens_remaining, retry_steps) + """ + sha_tbra, _ = await _ensure_scripts() + + now_step = _now_step() + + result = await _exec_script(sha_tbra, key, capacity, rate, now_step) + + allow, tokens_scaled, retry_steps = result + + return bool(allow), tokens_scaled / _SCALE, int(retry_steps) + + +async def execute_gcra( + key: str, + interval: int, + tolerance: int, +) -> tuple[bool, float, int]: + """ + Layer 1: Execute GCRA script. + + Args: + key: Full Redis key + interval: Steps between requests at steady rate + tolerance: Burst tolerance in steps + + Returns: + (allow, tokens_remaining, retry_steps) + """ + _, sha_gcra = await _ensure_scripts() + + now_step = _now_step() + + result = await _exec_script(sha_gcra, key, interval, tolerance, now_step) + + allow, tokens_remaining, retry_steps = result + + return bool(allow), float(tokens_remaining), int(retry_steps) + + +# ============================================================================= +# Layer 2: Library API +# ============================================================================= + + +class Algorithm(Enum): + TBRA = "tbra" + GCRA = "gcra" + + +class FailureMode(Enum): + OPEN = "open" + CLOSED = "closed" + + +@dataclass(frozen=True) +class ThrottleResult: + key: str + allow: bool + tokens_remaining: Optional[float] + retry_after_ms: Optional[int] + + @property + def retry_after_seconds(self) -> float: + if not self.retry_after_ms or self.retry_after_ms <= 0: + return 0.0 + return self.retry_after_ms / 1000.0 + + +def _build_key(key: Union[str, dict]) -> str: + """ + Build Redis key from str or dict. 
+ + If str: use as-is + If dict: join sorted key-value pairs with ':' + + Examples: + _build_key("global") -> "throttle:global" + _build_key({"org": "abc123"}) -> "throttle:org:abc123" + _build_key({"ep": "users", "org": "abc123"}) -> "throttle:ep:users:org:abc123" + """ + if isinstance(key, dict): + key_str = ":".join(f"{k}:{v}" for k, v in sorted(key.items())) + elif isinstance(key, str): + key_str = key + else: + raise TypeError("key must be str or dict") + + return f"throttle:{key_str}" + + +def _key_to_str(key: Union[str, dict]) -> str: + """Convert key to string for result/logging.""" + if isinstance(key, dict): + return ":".join(f"{k}:{v}" for k, v in sorted(key.items())) + + return key + + +def _to_tbra_params(max_capacity: int, refill_rate: int) -> tuple[int, int]: + """ + Convert to TBRA params. + + Args: + max_capacity: Burst size (tokens) + refill_rate: Tokens per minute + + Returns: + (capacity, rate) + """ + rate = (refill_rate * _SCALE * TIME_STEP_MS) // 60000 + + if rate < 1: + rate = 1 + + capacity = max_capacity * _SCALE + + return capacity, rate + + +def _to_gcra_params(max_capacity: int, refill_rate: int) -> tuple[int, int]: + """ + Convert to GCRA params. + + Args: + max_capacity: Burst tolerance (requests) + refill_rate: Requests per minute + + Returns: + (interval, tolerance) + """ + interval = 60000 // (refill_rate * TIME_STEP_MS) if refill_rate > 0 else 1 + + if interval < 1: + interval = 1 + + tolerance = max_capacity * interval + + return interval, tolerance + + +def _failure_result(key: str, failure_mode: FailureMode) -> ThrottleResult: + return ThrottleResult( + allow=(failure_mode == FailureMode.OPEN), + tokens_remaining=None, + retry_after_ms=None, + key=key, + ) + + +async def check_throttle( + key: Union[str, dict], + max_capacity: int, + refill_rate: int, + algorithm: Algorithm = Algorithm.TBRA, + failure_mode: FailureMode = FailureMode.OPEN, +) -> ThrottleResult: + """ + Layer 2: Check rate limit and consume one token. 
+ + Args: + key: Bucket key - str or dict + str: "global", "org:123", "ep:users:org:123" + dict: {"org": "123"}, {"ep": "users", "org": "123"} + max_capacity: Burst size / max tokens + refill_rate: Tokens per minute + algorithm: TBRA or GCRA + failure_mode: OPEN (allow) or CLOSED (deny) on Redis failure + + Returns: + ThrottleResult with decision and timing + """ + full_key = _build_key(key) + key_str = _key_to_str(key) + + try: + if algorithm == Algorithm.TBRA: + max_cap_s, refill_s = _to_tbra_params( + max_capacity, + refill_rate, + ) + + allow, tokens, retry_steps = await execute_tbra( + full_key, + max_cap_s, + refill_s, + ) + + return ThrottleResult( + allow=allow, + tokens_remaining=max(0.0, tokens), + retry_after_ms=retry_steps * TIME_STEP_MS, + key=key_str, + ) + + elif algorithm == Algorithm.GCRA: + interval, tolerance = _to_gcra_params( + max_capacity, + refill_rate, + ) + + allow, tokens_remaining, retry_steps = await execute_gcra( + full_key, + interval, + tolerance, + ) + + return ThrottleResult( + allow=allow, + tokens_remaining=tokens_remaining, + retry_after_ms=retry_steps * TIME_STEP_MS, + key=key_str, + ) + + else: + log.warning("[throttle] Unknown algorithm", algorithm=algorithm) + + return _failure_result(key_str, failure_mode) + + except Exception: + log.warning("[throttle] Unexpected error", key=key_str, exc_info=True) + + return _failure_result(key_str, failure_mode) + + +# ============================================================================= +# Layer 2: Batch API +# ============================================================================= + + +async def check_throttles( + checks: list[tuple[Union[str, dict], int, int]], + algorithm: Algorithm = Algorithm.TBRA, + failure_mode: FailureMode = FailureMode.OPEN, +) -> list[ThrottleResult]: + """ + Check multiple rate limits in a pipeline. 
+ + Args: + checks: List of (key, max_capacity, refill_rate) where key is str or dict + algorithm: TBRA or GCRA + failure_mode: OPEN or CLOSED on failure + + Returns: + List of ThrottleResult + """ + if not checks: + return [] + + if algorithm not in (Algorithm.TBRA, Algorithm.GCRA): + log.warning("[throttle] [batch] Unknown algorithm", algorithm=algorithm) + + return [_failure_result(_key_to_str(key), failure_mode) for key, _, _ in checks] + + # Pre-process keys + processed = [] + + for key, max_capacity, refill_rate in checks: + full_key = _build_key(key) + key_str = _key_to_str(key) + + processed.append((full_key, key_str, max_capacity, refill_rate)) + + try: + r = _get_redis() + + sha_tbra, sha_gcra = await _ensure_scripts() + + sha = sha_tbra if algorithm == Algorithm.TBRA else sha_gcra + + now_step = _now_step() + + pipe = r.pipeline(transaction=False) + + for full_key, _, max_capacity, refill_rate in processed: + if algorithm == Algorithm.TBRA: + max_cap_s, refill_s = _to_tbra_params(max_capacity, refill_rate) + pipe.evalsha(sha, 1, full_key, max_cap_s, refill_s, now_step) + + elif algorithm == Algorithm.GCRA: + interval, tolerance = _to_gcra_params(max_capacity, refill_rate) + pipe.evalsha(sha, 1, full_key, interval, tolerance, now_step) + + raw_results = await pipe.execute() + + results = [] + + for (_, key_str, max_capacity, _), raw in zip(processed, raw_results): + if algorithm == Algorithm.TBRA: + allow, tokens_scaled, retry_steps = raw + results.append( + ThrottleResult( + allow=bool(allow), + tokens_remaining=max(0.0, tokens_scaled / _SCALE), + retry_after_ms=int(retry_steps) * TIME_STEP_MS, + key=key_str, + ) + ) + + elif algorithm == Algorithm.GCRA: + allow, tokens_remaining, retry_steps = raw + results.append( + ThrottleResult( + allow=bool(allow), + tokens_remaining=float(tokens_remaining), + retry_after_ms=int(retry_steps) * TIME_STEP_MS, + key=key_str, + ) + ) + + return results + + except Exception: + log.warning("[throttle] [batch] Unexpected error", exc_info=True) + + return [_failure_result(ks, failure_mode) for _, ks, _, _ in processed] + + +# ============================================================================= +# Layer 2: Utilities +# ============================================================================= + + +async def peek_throttle(key: Union[str, dict]) -> Optional[dict]: + """View bucket state without consuming.""" + try: + r = _get_redis() + + full_key = _build_key(key) + val = await r.get(full_key) + + if not val: + return None + + val_str = val.decode() if isinstance(val, bytes) else val + + if "|" in val_str: + tokens_str, ts_str = val_str.split("|") + + return {"tokens": float(tokens_str) / _SCALE, "last_step": int(ts_str)} + + else: + return {"tat": int(val_str)} + + except Exception as e: + log.warning("[throttle] PEEK ERROR", error=str(e)) + + return None + + +async def reset_throttle(key: Union[str, dict]) -> bool: + """Delete bucket.""" + try: + r = _get_redis() + + full_key = _build_key(key) + + return await r.delete(full_key) > 0 + + except Exception as e: + log.warning("[throttle] RESET ERROR", error=str(e)) + + return False + + +# ============================================================================= +# Layer 3 Helpers: For building middleware/decorators +# ============================================================================= + +# Type for param resolver callback +ThrottleParamsResolver = Callable[ + [Any], # request or context + Awaitable[tuple[Union[str, dict], int, int]], # (key, max_capacity, refill_rate) +] + +# 
Default params for simple usage +DEFAULT_MAX_CAPACITY = 1000 +DEFAULT_REFILL_RATE = 100 # per minute +DEFAULT_ALGORITHM = Algorithm.TBRA diff --git a/docs/designs/api-rate-limiting/PR.md b/docs/designs/api-rate-limiting/PR.md new file mode 100644 index 0000000000..9cdae34fa3 --- /dev/null +++ b/docs/designs/api-rate-limiting/PR.md @@ -0,0 +1,100 @@ +# PR - API Rate Limiting (Throttling) + +## Executive summary + +This PR introduces a Redis-based API rate limiting system with support for multiple algorithms +(TBRA and GCRA), plan-based throttle policies via the entitlements system, and middleware +integration for automatic enforcement. The implementation follows a three-layer architecture +separating Lua scripts, library API, and middleware concerns. + +## Change inventory (organized by area) + +### OSS backend (throttling library) + +- New `throttling.py` utility with Redis Lua scripts for TBRA and GCRA algorithms. +- Layer 1: Raw script execution (`execute_tbra`, `execute_gcra`). +- Layer 2: Public API (`check_throttle`, `check_throttles`) with key building and failure handling. +- Utility functions: `peek_throttle`, `reset_throttle` for debugging/admin. + +Key files: +- `api/oss/src/utils/throttling.py` + +### EE backend (middleware and entitlements) + +- New throttling middleware that enforces rate limits after authentication. +- Integration with entitlements system via `Tracker.THROTTLES`. +- Plan-based throttle policies for HOBBY, PRO, BUSINESS tiers. +- Category-based endpoint grouping (STANDARD, CORE_FAST, TRACING_SLOW, etc.). +- Subscription caching for plan resolution. + +Key files: +- `api/ee/src/services/throttling_service.py` +- `api/ee/src/core/entitlements/types.py` + +### Entitlements expansion + +- New `Tracker.THROTTLES` tracker type for rate limit policies. +- New types: `Bucket`, `Throttle`, `Mode`, `Category`, `Method`. +- `ENDPOINTS` registry mapping categories to endpoint patterns. +- Throttle definitions per plan in `ENTITLEMENTS` dict. + +Key files: +- `api/ee/src/core/entitlements/types.py` + +### Documentation + +- Design specs covering concepts, policies, algorithms, implementation, and middleware. +- QA checklist for manual testing. + +Key files: +- `docs/designs/api-rate-limiting/README.md` +- `docs/designs/api-rate-limiting/throttling.*.specs.md` +- `docs/designs/api-rate-limiting/QA.md` + +## Behavior and policy changes + +- Authenticated requests are rate-limited based on organization's subscription plan. +- Different limits apply to different endpoint categories (STANDARD, FAST, SLOW). +- SLOW category endpoints (analytics, queries) have burst capacity but very low refill rate. +- Admin users bypass rate limiting. +- Unauthenticated requests are not rate-limited (IP-based limiting not yet implemented). +- On Redis failure, requests are allowed (fail-open). 
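+
+For reviewers, a minimal client-side sketch of the intended 429 handling (a hypothetical snippet using `httpx`; the URL and retry policy are placeholders, not part of this PR):
+
+```python
+import asyncio
+
+import httpx
+
+
+async def call_with_retry(client: httpx.AsyncClient, url: str) -> httpx.Response:
+    # Honor the Retry-After header set by the new throttling middleware
+    response = await client.get(url)
+    if response.status_code == 429:
+        retry_after = int(response.headers.get("Retry-After", "1"))
+        await asyncio.sleep(retry_after)
+        response = await client.get(url)
+    return response
+```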
+
+## Rate limits by plan
+
+| Plan | Category | Capacity | Rate/min |
+|------|----------|----------|----------|
+| HOBBY | STANDARD | 120 | 120 |
+| HOBBY | FAST | 1,200 | 1,200 |
+| HOBBY | SLOW | 120 | 1 |
+| PRO | STANDARD | 360 | 360 |
+| PRO | FAST | 3,600 | 3,600 |
+| PRO | SLOW | 180 | 1 |
+| BUSINESS | STANDARD | 3,600 | 3,600 |
+| BUSINESS | FAST | 36,000 | 36,000 |
+| BUSINESS | SLOW | 1,800 | 1 |
+
+## Response format
+
+429 responses include:
+- `Retry-After` header with seconds until retry (when a retry time can be computed)
+- `X-RateLimit-Limit` header with bucket capacity
+- `X-RateLimit-Remaining` header with remaining tokens
+- Body: `{"detail": "Rate limit exceeded. Please retry after <seconds> seconds."}` (falls back to a generic message when no retry time is available)
+
+Successful responses carry `X-RateLimit-Remaining` with the lowest remaining count across matched policies.
+
+## Risks and considerations
+
+- Redis dependency: Rate limiting requires the volatile Redis instance. Failure mode is open (allow).
+- Algorithm choice: The middleware defaults to GCRA; TBRA remains available (and is the library-level default) but is not used by current policies.
+- No IP limiting: Unauthenticated endpoints are unprotected. Future work needed.
+- SLOW category limits: Very restrictive (1/min after burst). May need adjustment based on usage.
+- Cache invalidation: Subscription updates actively invalidate the cached plan; other plan changes rely on TTL expiry.
+
+## Suggested validation
+
+- Follow the QA checklist in `docs/designs/api-rate-limiting/QA.md`.
+- Test rate limiting for each plan tier.
+- Verify 429 responses include correct headers.
+- Test Redis failure behavior (should allow requests).
+- Verify admin bypass works correctly.
+- Load test to confirm limits are enforced accurately.
diff --git a/docs/designs/api-rate-limiting/QA.md b/docs/designs/api-rate-limiting/QA.md
new file mode 100644
index 0000000000..47ae53192b
--- /dev/null
+++ b/docs/designs/api-rate-limiting/QA.md
@@ -0,0 +1,174 @@
+# API Rate Limiting QA
+
+Manual QA checklist for rate limiting functionality. Use alongside the specs in this folder.
+
+---
+
+## Basic Rate Limiting
+
+### Single request within limit
+- Preconditions: Fresh bucket, authenticated user.
+- Steps: Make single API request.
+- Expected: Request succeeds; the response carries an `X-RateLimit-Remaining` header with the lowest remaining count across matched policies.
+
+### Exceed rate limit
+- Preconditions: Fresh bucket with known capacity.
+- Steps: Make requests exceeding capacity in quick succession.
+- Expected: 429 response with `Retry-After`, `X-RateLimit-Limit`, `X-RateLimit-Remaining` headers.
+
+### Wait and retry
+- Preconditions: Rate limited (received 429).
+- Steps: Wait for `Retry-After` seconds, retry request.
+- Expected: Request succeeds.
+
+---
+
+## Plan-Based Limits
+
+### HOBBY tier limits
+- Preconditions: Organization on HOBBY plan.
+- Steps: Test STANDARD, FAST, and SLOW category endpoints.
+- Expected: Limits match HOBBY tier (120/120, 1200/1200, 120/1).
+
+### PRO tier limits
+- Preconditions: Organization on PRO plan.
+- Steps: Test STANDARD, FAST, and SLOW category endpoints.
+- Expected: Limits match PRO tier (360/360, 3600/3600, 180/1).
+
+### BUSINESS tier limits
+- Preconditions: Organization on BUSINESS plan.
+- Steps: Test STANDARD, FAST, and SLOW category endpoints.
+- Expected: Limits match BUSINESS tier (3600/3600, 36000/36000, 1800/1).
+
+### Plan upgrade reflects in limits
+- Preconditions: Organization on HOBBY, then upgraded to PRO.
+- Steps: Exhaust HOBBY limit, upgrade plan, retry.
+- Expected: New PRO limits apply once the subscription cache is refreshed (updates invalidate it actively; otherwise the TTL applies).
+
+---
+
+## Category-Based Limits
+
+### STANDARD category
+- Preconditions: Authenticated user.
+- Steps: Call any non-categorized endpoint.
+- Expected: STANDARD limits apply.
+
+### FAST category (CORE_FAST, TRACING_FAST, SERVICES_FAST)
+- Preconditions: Authenticated user.
+- Steps: Call categorized endpoints (e.g., `POST */retrieve`, `POST /otlp/v1/traces`).
+- Expected: FAST limits apply (higher than STANDARD).
+
+### SLOW category (CORE_SLOW, TRACING_SLOW, SERVICES_SLOW)
+- Preconditions: Authenticated user.
+- Steps: Call slow endpoints (e.g., `POST /tracing/*/query`, `POST /tracing/spans/analytics`).
+- Expected: SLOW limits apply (burst capacity, then 1/min refill).
+
+### Multiple categories isolated
+- Preconditions: Authenticated user.
+- Steps: Exhaust STANDARD limit, then call FAST endpoint.
+- Expected: FAST endpoint succeeds (separate bucket).
+
+---
+
+## Bypass and Edge Cases
+
+### Admin bypass
+- Preconditions: Admin user.
+- Steps: Make requests exceeding normal limits.
+- Expected: All requests succeed, no rate limiting.
+
+### Unauthenticated requests
+- Preconditions: No authentication.
+- Steps: Call public endpoints.
+- Expected: Requests not rate-limited (IP limiting not implemented).
+
+### Missing organization
+- Preconditions: Request without organization context.
+- Steps: Make API request.
+- Expected: Request not rate-limited, passes through.
+
+---
+
+## Redis Failure Modes
+
+### Redis unavailable
+- Preconditions: Redis down or unreachable.
+- Steps: Make API requests.
+- Expected: Requests succeed (fail-open mode).
+
+### Redis timeout
+- Preconditions: Redis slow (>100ms response).
+- Steps: Make API requests.
+- Expected: Requests succeed after the timeout; a warning is logged.
+
+### Redis recovers
+- Preconditions: Redis was down, now recovered.
+- Steps: Make API requests.
+- Expected: Rate limiting resumes, buckets start fresh.
+
+---
+
+## Response Validation
+
+### 429 response format
+- Preconditions: Rate limited.
+- Steps: Examine 429 response.
+- Expected:
+  - Status: 429
+  - Body: `{"detail": "Rate limit exceeded. Please retry after <seconds> seconds."}`
+  - Header: `Retry-After: <seconds>`
+  - Header: `X-RateLimit-Limit: <capacity>`
+  - Header: `X-RateLimit-Remaining: 0`
+
+### Retry-After accuracy
+- Preconditions: Rate limited with known refill rate.
+- Steps: Note `Retry-After` value, wait that duration, retry.
+- Expected: Request succeeds within 1-2 seconds of indicated time.
+
+---
+
+## Bucket Key Isolation
+
+### Different organizations isolated
+- Preconditions: Two organizations, same plan.
+- Steps: Exhaust limit for org A, make request for org B.
+- Expected: Org B request succeeds (separate bucket).
+
+### Same organization, different categories isolated
+- Preconditions: Single organization.
+- Steps: Exhaust STANDARD limit, call FAST endpoint.
+- Expected: FAST request succeeds (separate bucket).
+
+### Plan change creates new bucket
+- Preconditions: Organization changes plan.
+- Steps: Exhaust limit on old plan, upgrade, wait for cache refresh, retry.
+- Expected: New bucket with new plan limits.
+
+---
+
+## Algorithm Behavior (GCRA)
+
+### Smooth scheduling
+- Preconditions: Fresh bucket.
+- Steps: Make requests at steady rate just below limit.
+- Expected: All requests succeed, no bursts of 429s.
+
+### Burst then throttle
+- Preconditions: Fresh bucket with burst capacity.
+- Steps: Make burst of requests, then continue at rate.
+- Expected: Burst succeeds, then smoothly throttled to rate.
+
+### Remaining tokens accuracy
+- Preconditions: Fresh bucket.
+- Steps: Make N requests, check `X-RateLimit-Remaining` on 429.
+- Expected: Remaining decreases predictably.
+
+---
+
+## Notes
+
+- SLOW category has a very restrictive refill (1/min); test burst exhaustion carefully.
+- Subscription updates invalidate the cached plan; otherwise the cache TTL governs plan change propagation.
+- Use the Redis CLI (`redis-cli --scan --pattern "throttle:*"`) to inspect bucket keys; prefer `--scan` over `KEYS` on shared instances.
+- Use the `peek_throttle` utility for debugging bucket state without consuming tokens.
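+
+For the bucket-exhaustion steps above, a small helper can drive the requests (a sketch; `BASE_URL` and `TOKEN` are placeholders for your test environment, and the auth header format is an assumption):
+
+```python
+import httpx
+
+BASE_URL = "http://localhost"  # placeholder
+TOKEN = "test-api-key"  # placeholder
+
+
+def exhaust_bucket(path: str, attempts: int) -> int:
+    """Fire requests until the first 429; returns how many were allowed."""
+    headers = {"Authorization": f"Bearer {TOKEN}"}  # assumed auth scheme
+    with httpx.Client(base_url=BASE_URL, headers=headers) as client:
+        for sent in range(attempts):
+            if client.get(path).status_code == 429:
+                return sent
+    return attempts
+```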
diff --git a/docs/designs/api-rate-limiting/README.md b/docs/designs/api-rate-limiting/README.md
new file mode 100644
index 0000000000..b4610f7a4d
--- /dev/null
+++ b/docs/designs/api-rate-limiting/README.md
@@ -0,0 +1,82 @@
+# API Rate Limiting
+
+Redis-based rate limiting for API protection with support for multiple algorithms and flexible policy definitions.
+
+## Overview
+
+This system provides distributed rate limiting using Redis with:
+- **Two algorithms**: TBRA (Token Bucket) and GCRA (Generic Cell Rate Algorithm)
+- **PAR(C) policies**: Principal, Action, Resource, Condition model
+- **Three-layer architecture**: Scripts → Library → Middleware
+- **Fail-safe modes**: Open or closed on Redis failure
+
+## Documentation
+
+| Document | Description |
+|----------|-------------|
+| [throttling.concepts.specs.md](throttling.concepts.specs.md) | Core vocabulary: principals, categories, plans, enforcement |
+| [throttling.policies.specs.md](throttling.policies.specs.md) | PAR(C) policy model, scoping modes, examples |
+| [throttling.algorithms.specs.md](throttling.algorithms.specs.md) | TBRA vs GCRA algorithms, trade-offs, parameters |
+| [throttling.implementation.specs.md](throttling.implementation.specs.md) | Three-layer architecture, Lua scripts, Python API |
+| [throttling.middleware.specs.md](throttling.middleware.specs.md) | Middleware design, entitlements integration, plan-based limits |
+
+## Quick Start
+
+```python
+from oss.src.utils.throttling import check_throttle, Algorithm
+
+# Simple usage
+result = await check_throttle("global", max_capacity=100, refill_rate=60)
+
+# With dict key
+result = await check_throttle({"org": org_id}, max_capacity=100, refill_rate=60)
+
+# With GCRA explicitly (the middleware uses GCRA; the library API defaults to TBRA)
+result = await check_throttle(
+    {"org": org_id, "policy": "cats:standard"},
+    max_capacity=50,
+    refill_rate=30,
+    algorithm=Algorithm.GCRA,
+)
+
+if not result.allow:
+    # Return 429 with Retry-After header
+    retry_after = result.retry_after_seconds
+```
+
+## Key Concepts
+
+**Principal**: Who is being limited (organization_id; IP not yet implemented)
+
+**Categories**: Named endpoint groups (STANDARD, CORE_FAST, TRACING_SLOW, etc.)
+
+**Plan**: Subscription tier that determines limits (HOBBY, PRO, BUSINESS)
+
+**Policy**: PAR(C) rule mapping principal + plan + categories → bucket parameters
+
+## Bucket Key Format
+
+```
+throttle:organization:{org_id}:plan:{plan}:policy:{slug}
+```
+
+Examples:
+- `throttle:organization:org_abc123:plan:cloud_v0_pro:policy:cats:standard`
+- `throttle:organization:org_abc123:plan:cloud_v0_hobby:policy:cats:core_fast,services_fast,tracing_fast`
+
+## Response Headers
+
+On 429 response:
+```
+Retry-After: 2
+X-RateLimit-Limit: 100
+X-RateLimit-Remaining: 0
+```
+
+## Design Principles
+
+1. **Redis is enforcement, not billing** — billing-grade quotas exist elsewhere
+2. **Organization-first** — primary identifier is org_id, not user_id
+3. **Fail-open by default** — allow requests when Redis is unavailable
+4. **Atomic operations** — Lua scripts ensure correctness under concurrency
+5. 
**GCRA by default** — smooth scheduling, minimal state, predictable behavior diff --git a/docs/designs/api-rate-limiting/throttling.algorithms.specs.md b/docs/designs/api-rate-limiting/throttling.algorithms.specs.md new file mode 100644 index 0000000000..a4204e3fd7 --- /dev/null +++ b/docs/designs/api-rate-limiting/throttling.algorithms.specs.md @@ -0,0 +1,245 @@ +# Algorithms: TBRA vs GCRA + +Two algorithms are supported, both using the same interface: `max_capacity` and `refill_rate`. + +--- + +## TBRA (Token Bucket Rate Algorithm) + +### Intuition + +Token bucket models rate as: +- A bucket has **capacity** (max tokens) that caps burst +- Tokens **refill continuously** at `refill_rate` per minute +- Each request **consumes 1 token** +- If insufficient tokens → request denied until tokens refill + +**Key property**: Tokens can be "banked" while idle. A quiet principal can later spend a burst up to the bucket capacity. + +### Parameters + +| Parameter | Description | +|-----------|-------------| +| `max_capacity` | Burst size (max tokens) | +| `refill_rate` | Tokens per minute | + +### State in Redis + +``` +"tokens_scaled|last_step" +``` + +- `tokens_scaled`: Current tokens × 1000 (fixed-point) +- `last_step`: Last update time in steps + +### Algorithm + +``` +on request at now_step: + elapsed = now_step - last_step + tokens = tokens + elapsed * refill_per_step + tokens = min(max_capacity, tokens) + + tokens = tokens - 1 + if tokens >= 0: + allow = true + retry_after = 0 + else: + allow = false + retry_after = ceil(-tokens / refill_per_step) + + store(tokens, now_step) + return (allow, tokens, retry_after) +``` + +### What You Get + +- Very intuitive semantics ("banked burst") +- Meaningful "remaining tokens" for headers +- Tokens accumulate when idle + +### What You Lose + +- Slightly more CPU (parsing, refill math) +- State is slightly larger (two values) + +--- + +## GCRA (Generic Cell Rate Algorithm) + +### Intuition + +GCRA is a leaky bucket / scheduling approach: +- Enforces average **spacing between requests** +- Allows bursts via **tolerance** (how early a request can be) +- Stores only one value: **TAT** (Theoretical Arrival Time) + +**Key property**: Very smooth, predictable enforcement with minimal state. 
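+
+To make the scheduling concrete, here is a minimal, self-contained simulation of the update rule described below (illustrative only; the enforced version is the Lua script in the implementation spec):
+
+```python
+def gcra_allow(tat: int, now: int, interval: int, tolerance: int) -> tuple[bool, int]:
+    """One GCRA decision; a denied request leaves the TAT unchanged."""
+    if now < tat - tolerance:
+        return False, tat
+    return True, max(tat, now) + interval
+
+
+# interval=2 steps, tolerance=4 steps (i.e., max_capacity=2)
+tat = 0  # first request initializes TAT to "now"
+for now in (0, 0, 0, 0):
+    allow, tat = gcra_allow(tat, now, interval=2, tolerance=4)
+    print(f"now={now} allow={allow} tat={tat}")
+# -> allow, allow, allow, deny: the tolerance admits two early requests
+#    on top of the one that is exactly on schedule
+```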
+ +### Parameters + +| Parameter | Description | +|-----------|-------------| +| `max_capacity` | Burst tolerance (in requests) | +| `refill_rate` | Requests per minute | + +Derived: +- `interval = 60000 / (refill_rate * TIME_STEP_MS)` — steps between requests +- `tolerance = max_capacity * interval` — burst tolerance in steps + +### State in Redis + +``` +"tat" +``` + +- `tat`: Theoretical arrival time (single integer) + +### Algorithm + +``` +on request at now_step: + tat = get(key) or now_step + limit = tat - tolerance + + if now_step < limit: + allow = false + retry_after = limit - now_step + new_tat = tat + else: + allow = true + retry_after = 0 + new_tat = max(tat, now_step) + interval + + store(new_tat) + return (allow, retry_after) +``` + +### What You Get + +- Extremely fast (one integer state) +- Smooth, predictable behavior +- Very stable under high load +- Clear retry-after calculation + +### What You Lose + +- No explicit "banked tokens" concept +- Burst is tolerance, not stored tokens + +--- + +## Comparison + +| Property | TBRA | GCRA | +|----------|------|------| +| State size | 2 values (tokens + ts) | 1 value (tat) | +| CPU | Slightly more (refill math) | Minimal | +| "Remaining tokens" | Yes (natural) | Yes (computed from tolerance) | +| Burst semantics | Banked tokens | Tolerance window | +| After idle period | Full bucket available | Tolerance available | +| Headers UX | `X-RateLimit-Remaining` works | `X-RateLimit-Remaining` works | + +### Semantics Difference + +**TBRA**: If idle for 10 minutes with 60 tokens/min refill and capacity 100: +- Bucket fills to 100 tokens +- Can make 100 requests immediately + +**GCRA**: Same parameters, idle for 10 minutes: +- Can make up to `max_capacity` requests with tolerance +- Then must respect spacing interval + +In most real workloads, these feel identical. The difference shows in long-idle-then-burst patterns. + +--- + +## Best Fit Guidance + +**Choose TBRA if**: +- "Banked burst after idle" is a product expectation +- You want strong "remaining tokens" UX +- Clients expect `X-RateLimit-Remaining` header + +**Choose GCRA if**: +- Maximum throughput and simplest implementation +- Smooth scheduling and predictable retry-after +- Burst semantics as tolerance is acceptable + +**Pragmatic strategy**: +- Use GCRA everywhere by default (fast, smooth) +- Reserve TBRA for places where "banked burst" is a deliberate customer promise + +--- + +## Performance Optimizations + +Both algorithms use these optimizations: + +### Time Quantization + +Use 1-second steps instead of milliseconds: +```python +TIME_STEP_MS = 1000 +now_step = int(time.time() * 1000) // TIME_STEP_MS +``` + +Effect: Refill happens in chunks, not perfectly continuous. Acceptable for most use cases. + +### Fixed-Point Arithmetic + +Store tokens as integers scaled by 1000: +```python +_SCALE = 1000 +tokens_scaled = capacity * _SCALE +``` + +Effect: Faster math, no floating-point drift. + +### App-Provided Time + +Pass `now` from the API instead of calling Redis `TIME`: +```python +now_step = _now_step() # computed in Python +``` + +Effect: Removes one Redis call. Small clock skew is acceptable for enforcement. + +### Hardcoded TTL + +TTL is hardcoded in the script (60 minutes): +```lua +redis.call('SET', key, value, 'PX', 3600000) +``` + +Effect: One less parameter to pass. 
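+
+Putting the first two optimizations together, a short worked example using the constants above:
+
+```python
+TIME_STEP_MS = 1000
+_SCALE = 1000
+
+# 120 tokens/minute becomes 2000 scaled tokens per 1-second step,
+# i.e., 2 whole tokens per step:
+refill_per_step_scaled = (120 * _SCALE * TIME_STEP_MS) // 60000  # -> 2000
+
+# A bucket holding 7.5 tokens is stored as the integer 7500, so the
+# Lua script only ever performs integer arithmetic:
+tokens_scaled = int(7.5 * _SCALE)  # -> 7500
+```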
+ +--- + +## Script Contracts + +### TBRA Script + +**Inputs**: +- `KEYS[1]`: bucket key +- `ARGV[1]`: max_cap_scaled +- `ARGV[2]`: refill_per_step_scaled +- `ARGV[3]`: now_step + +**Outputs**: +- `[0]`: allow (0 or 1) +- `[1]`: tokens_scaled (current tokens × 1000) +- `[2]`: retry_steps (steps until allowed) + +### GCRA Script + +**Inputs**: +- `KEYS[1]`: bucket key +- `ARGV[1]`: interval (steps between requests) +- `ARGV[2]`: tolerance (burst tolerance in steps) +- `ARGV[3]`: now_step + +**Outputs** (same order as TBRA): +- `[0]`: allow (0 or 1) +- `[1]`: remaining (remaining burst capacity in requests) +- `[2]`: retry_steps (steps until allowed) diff --git a/docs/designs/api-rate-limiting/throttling.concepts.specs.md b/docs/designs/api-rate-limiting/throttling.concepts.specs.md new file mode 100644 index 0000000000..763295a2f6 --- /dev/null +++ b/docs/designs/api-rate-limiting/throttling.concepts.specs.md @@ -0,0 +1,103 @@ +# Core Concepts + +## Goals + +- **Protect reliability**: Prevent overload of the API and downstream dependencies +- **Fairness**: Prevent any single tenant from consuming disproportionate capacity +- **Plan-aware enforcement**: Apply different limits based on subscription plan +- **Endpoint-aware enforcement**: Apply different limits based on route group +- **Support bursts**: Allow short-term bursts while preserving average rate +- **Distributed correctness**: Work correctly with multiple API instances + +## Non-Goals + +- **Billing-grade counting**: Monthly quotas and invoices are handled elsewhere +- **User-based limits**: Primary identifier is organization_id, not user_id +- **Perfect IP fairness**: IP is a guardrail for unauthenticated routes only + +--- + +## Principal (Who) + +A **principal** is "who is being limited". + +| Type | Usage | Example | Status | +|------|-------|---------|--------| +| `organization_id` | Primary identifier for authenticated traffic | `org_abc123` | ✅ Implemented | +| `ip` | Unauthenticated routes or secondary guardrail | `192.168.1.100` | ⏳ Not yet implemented | + +**Note**: IP limits are inaccurate due to NAT, proxies, and mobile networks. Use IP limits primarily as a safety net. + +**Current implementation**: Only `organization_id` is supported. Unauthenticated requests bypass rate limiting. + +--- + +## Endpoint Groups (What) + +Endpoint groups are **named collections** of related endpoints defined in a central registry. + +| Group | Description | Endpoints | +|-------|-------------|-----------| +| `otlp` | OpenTelemetry ingest | `POST /v1/otlp/traces` | +| `queries` | Span and analytics queries | `/v1/spans/query`, `/v1/analytics/query` | +| `public` | Unauthenticated endpoints | Public health checks, docs | +| `auth` | Authentication endpoints | `/v1/auth/*`, `/v1/supertokens/*` | +| `registry` | Entity fetch/retrieve | Testsets, apps, evals, traces lookups | + +Groups are defined once in configuration and referenced by name in policies. + +--- + +## Plan (Condition) + +A **plan** determines which contract applies to a principal. + +| Plan | Description | +|------|-------------| +| `free` | Free tier with basic limits | +| `pro` | Professional tier with higher limits | +| `enterprise` | Enterprise tier with custom limits | +| `anonymous` | Unauthenticated traffic | + +The plan is resolved per request from the organization's configuration. + +--- + +## Enforcement + +The runtime act of checking policies and deciding: + +1. **Resolve principal** — organization_id for authenticated, IP for unauthenticated +2. 
**Resolve endpoint** — determine which groups apply +3. **Resolve plan** — look up organization's subscription +4. **Select policies** — filter applicable policies +5. **Enforce atomically** — check all applicable buckets +6. **Return decision** — allow or deny (429) + +### Deny Semantics + +- If **any** applicable policy denies → request is denied +- System records **which policy denied** (the limiting dimension) +- `Retry-After` header computed from the denying policy + +--- + +## Why Redis + +- **Centralized shared state** across many API instances +- **Atomic operations** via Lua scripts +- **Built-in TTLs** for automatic key expiration +- **High throughput** and low latency + +Redis state is typically **volatile** (acceptable reset on restart). Billing-grade persistence is handled elsewhere. + +--- + +## Failure Modes + +| Mode | Behavior | Use Case | +|------|----------|----------| +| `open` | Allow requests when Redis unavailable | Default, prevents outage | +| `closed` | Deny requests when Redis unavailable | Expensive/critical endpoints | + +The choice should be explicit and consistent per policy or globally configured. diff --git a/docs/designs/api-rate-limiting/throttling.implementation.specs.md b/docs/designs/api-rate-limiting/throttling.implementation.specs.md new file mode 100644 index 0000000000..551a86721e --- /dev/null +++ b/docs/designs/api-rate-limiting/throttling.implementation.specs.md @@ -0,0 +1,407 @@ +# Implementation + +Three-layer architecture for rate limiting with Redis. + +--- + +## Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Layer 3: User Code (Middleware/Decorators) │ +│ - Resolves key from request context │ +│ - Resolves params from plan lookup │ +│ - Handles 429 response │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Layer 2: Library API │ +│ check_throttle(key, max_capacity, refill_rate, ...) │ +│ - Accepts key as str or dict │ +│ - Converts to algorithm-specific params │ +│ - Handles Redis failures │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Layer 1: Lua Scripts │ +│ _exec_tbra(key, max_cap_scaled, refill_scaled) │ +│ _exec_gcra(key, interval, tolerance) │ +│ - Uses current time from caller │ +│ - Atomic read-modify-write │ +└─────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌───────────────┐ + │ Redis │ + └───────────────┘ +``` + +--- + +## Layer 1: Lua Scripts + +### TBRA Script + +```lua +local key = KEYS[1] +local max_cap = tonumber(ARGV[1]) +local refill = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) + +local val = redis.call('GET', key) +local tokens, last + +if val then + local sep = string.find(val, '|') + tokens = tonumber(string.sub(val, 1, sep - 1)) + last = tonumber(string.sub(val, sep + 1)) +else + tokens = max_cap + last = now +end + +local elapsed = now - last +if elapsed > 0 then + tokens = tokens + elapsed * refill + if tokens > max_cap then tokens = max_cap end +end + +tokens = tokens - 1000 +local allow = tokens >= 0 and 1 or 0 +local retry = allow == 1 and 0 or math.ceil(-tokens / refill) + +redis.call('SET', key, tokens .. '|' .. 
now, 'PX', 3600000)
+
+return {allow, tokens, retry}
+```
+
+### GCRA Script
+
+```lua
+local key = KEYS[1]
+local interval = tonumber(ARGV[1])
+local tolerance = tonumber(ARGV[2])
+local now = tonumber(ARGV[3])
+
+local tat = tonumber(redis.call('GET', key)) or now
+
+local limit = tat - tolerance
+local allow, retry, new_tat, remaining
+
+if now < limit then
+    allow = 0
+    retry = limit - now
+    new_tat = tat
+    remaining = 0
+else
+    allow = 1
+    retry = 0
+    new_tat = (tat > now and tat or now) + interval
+    -- remaining burst capacity before the next deny
+    local used = new_tat - now
+    if used < tolerance then
+        remaining = math.floor((tolerance - used) / interval)
+    else
+        remaining = 0
+    end
+end
+
+redis.call('SET', key, new_tat, 'PX', 3600000)
+
+return {allow, remaining, retry}
+```
+
+### Script Loading
+
+Scripts are loaded lazily with SCRIPT LOAD and invoked via EVALSHA:
+
+```python
+async def _ensure_scripts() -> tuple[str, str]:
+    global _sha_tbra, _sha_gcra
+    if _sha_tbra is None or _sha_gcra is None:
+        r = _get_redis()
+        _sha_tbra = await r.script_load(_LUA_TBRA)
+        _sha_gcra = await r.script_load(_LUA_GCRA)
+    return str(_sha_tbra), str(_sha_gcra)
+```
+
+### NOSCRIPT Handling
+
+When Redis restarts, scripts are flushed. Handle NOSCRIPT by reloading:
+
+```python
+async def _exec_script(sha: str, key: str, *args) -> list:
+    global _sha_tbra, _sha_gcra
+    r = _get_redis()
+    try:
+        return await r.evalsha(sha, 1, key, *args)
+    except Exception as e:
+        if "NOSCRIPT" in str(e):
+            _sha_tbra, _sha_gcra = None, None
+            await _ensure_scripts()
+            return await r.evalsha(sha, 1, key, *args)
+        raise
+```
+
+---
+
+## Layer 2: Library API
+
+### Key Building
+
+Keys can be `str` or `dict`:
+
+```python
+def _build_key(key: Union[str, dict]) -> str:
+    if isinstance(key, dict):
+        key_str = ":".join(f"{k}:{v}" for k, v in sorted(key.items()))
+    elif isinstance(key, str):
+        key_str = key
+    else:
+        raise TypeError("key must be str or dict")
+    return f"throttle:{key_str}"
+```
+
+Examples:
+- `"global"` → `throttle:global`
+- `{"org": "abc123"}` → `throttle:org:abc123`
+- `{"group": "llm", "org": "abc123"}` → `throttle:group:llm:org:abc123`
+
+### Parameter Conversion
+
+Convert user-friendly params to algorithm-specific:
+
+```python
+def _to_tbra_params(max_capacity: int, refill_rate: int) -> tuple[int, int]:
+    max_cap_scaled = max_capacity * _SCALE
+    refill_per_step_scaled = (refill_rate * TIME_STEP_MS * _SCALE) // 60000
+    if refill_per_step_scaled < 1:
+        refill_per_step_scaled = 1
+    return max_cap_scaled, refill_per_step_scaled
+
+def _to_gcra_params(max_capacity: int, refill_rate: int) -> tuple[int, int]:
+    interval = 60000 // (refill_rate * TIME_STEP_MS) if refill_rate > 0 else 1
+    if interval < 1:
+        interval = 1
+    tolerance = max_capacity * interval
+    return interval, tolerance
+```
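+
+For example, the HOBBY STANDARD policy (capacity 120, rate 120/min) converts as follows (worked numbers for the helpers above):
+
+```python
+# TBRA: 120 tokens scaled by 1000, refilling 2000 scaled tokens per step
+assert _to_tbra_params(120, 120) == (120000, 2000)
+
+# GCRA: 60000 // (120 * 1000) floors to 0 and is clamped to 1 step, so
+# rates above one request per step are quantized to the step resolution
+assert _to_gcra_params(120, 120) == (1, 120)
+```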
+
+### Public API
+
+```python
+async def check_throttle(
+    key: Union[str, dict],
+    max_capacity: int,
+    refill_rate: int,
+    algorithm: Algorithm = Algorithm.TBRA,
+    failure_mode: FailureMode = FailureMode.OPEN,
+) -> ThrottleResult:
+    """
+    Check rate limit and consume one token.
+
+    Args:
+        key: Bucket key - str or dict
+        max_capacity: Burst size / max tokens
+        refill_rate: Tokens per minute
+        algorithm: TBRA or GCRA
+        failure_mode: OPEN (allow) or CLOSED (deny) on Redis failure
+
+    Returns:
+        ThrottleResult with decision and timing
+    """
+```
+
+### Result Type
+
+```python
+@dataclass(frozen=True)
+class ThrottleResult:
+    key: str
+    allow: bool
+    tokens_remaining: Optional[float]  # tokens (TBRA) or remaining burst (GCRA); None on failure
+    retry_after_ms: Optional[int]
+
+    @property
+    def retry_after_seconds(self) -> float:
+        if not self.retry_after_ms or self.retry_after_ms <= 0:
+            return 0.0
+        return self.retry_after_ms / 1000.0
+```
+
+### Batch API
+
+Check multiple limits in a single pipeline:
+
+```python
+async def check_throttles(
+    checks: list[tuple[Union[str, dict], int, int]],
+    algorithm: Algorithm = Algorithm.TBRA,
+    failure_mode: FailureMode = FailureMode.OPEN,
+) -> list[ThrottleResult]:
+```
+
+### Utilities
+
+```python
+async def peek_throttle(key: Union[str, dict]) -> Optional[dict]:
+    """View bucket state without consuming."""
+
+async def reset_throttle(key: Union[str, dict]) -> bool:
+    """Delete bucket."""
+```
+
+---
+
+## Layer 3: Middleware
+
+User code resolves key and params from request context:
+
+```python
+from oss.src.utils.throttling import check_throttle, Algorithm, FailureMode
+
+async def rate_limit_middleware(request, call_next):
+    # 1. Resolve principal
+    org_id = request.state.organization_id
+
+    # 2. Resolve plan and get params
+    plan = await get_plan(org_id)
+    params = get_rate_limit_params(plan)
+
+    # 3. Check throttle
+    result = await check_throttle(
+        key={"org": org_id},
+        max_capacity=params.capacity,
+        refill_rate=params.refill_rate,
+    )
+
+    # 4. Handle denial
+    if not result.allow:
+        return JSONResponse(
+            status_code=429,
+            content={"detail": "Rate limit exceeded."},
+            headers={"Retry-After": str(int(result.retry_after_seconds) + 1)},
+        )
+
+    # 5. Proceed
+    return await call_next(request)
+```
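+
+The same Layer 2 call also fits a decorator, as the module docstring suggests. A sketch (the key layout and the reuse of `JSONResponse` from the middleware example above are assumptions):
+
+```python
+from functools import wraps
+
+
+def throttled(max_capacity: int, refill_rate: int):
+    def decorator(handler):
+        @wraps(handler)
+        async def wrapper(request, *args, **kwargs):
+            result = await check_throttle(
+                {"ep": handler.__name__, "org": request.state.organization_id},
+                max_capacity=max_capacity,
+                refill_rate=refill_rate,
+            )
+            if not result.allow:
+                return JSONResponse(
+                    status_code=429,
+                    content={"detail": "Rate limit exceeded."},
+                    headers={"Retry-After": str(int(result.retry_after_seconds) + 1)},
+                )
+            return await handler(request, *args, **kwargs)
+        return wrapper
+    return decorator
+```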
Proceed + return await call_next(request) +``` + +### Multi-Policy Enforcement + +```python +async def check_all_policies(org_id: str, endpoint_groups: list[str]): + checks = [] + + # Global limit + checks.append(({"org": org_id}, 1000, 500)) + + # Group-specific limits + if "llm" in endpoint_groups: + checks.append(({"group": "llm", "org": org_id}, 500, 300)) + + results = await check_throttles(checks) + + # Deny if any denies + for result in results: + if not result.allow: + return result # Return first denial + + return None # All allowed +``` + +--- + +## Configuration + +### Constants + +```python +# Time step: 1 second +TIME_STEP_MS = 1000 + +# Fixed-point scale for TBRA +_SCALE = 1000 + +# TTL: 60 minutes (hardcoded in scripts) +_TTL_MS = 3600000 + +# Redis socket timeout +THROTTLE_SOCKET_TIMEOUT = 0.1 +``` + +### Redis Client + +```python +def _get_redis() -> Redis: + global _redis + if _redis is None: + _redis = Redis.from_url( + url=env.redis.uri_volatile, + decode_responses=False, + socket_timeout=THROTTLE_SOCKET_TIMEOUT, + ) + return _redis +``` + +--- + +## Failure Handling + +### Fail-Open (Default) + +When Redis is unavailable, allow the request: + +```python +if failure_mode == FailureMode.OPEN: + return ThrottleResult( + allow=True, + tokens_remaining=None, + retry_after_ms=None, + key=key_str, + ) +``` + +### Fail-Closed + +When Redis is unavailable, deny the request: + +```python +if failure_mode == FailureMode.CLOSED: + return ThrottleResult( + allow=False, + tokens_remaining=None, + retry_after_ms=None, + key=key_str, + ) +``` + +--- + +## Response Headers + +On 429 response, include: + +```python +headers = { + "Retry-After": str(int(result.retry_after_seconds) + 1), + "X-RateLimit-Limit": str(max_capacity), + "X-RateLimit-Remaining": str(int(result.tokens_remaining or 0)), +} +``` + +For TBRA, `tokens_remaining` provides a meaningful value. +For GCRA, omit `X-RateLimit-Remaining` or set to 0. 
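
The middleware spec below refers to a `make_429` helper that turns a denial into a response. A minimal sketch under the header conventions above — the exact signature is an assumption, not the shipped implementation:

```python
from typing import Optional

from fastapi.responses import JSONResponse

from oss.src.utils.throttling import ThrottleResult  # result type defined above


def make_429(result: ThrottleResult, max_capacity: Optional[int] = None) -> JSONResponse:
    # Round the retry hint up to whole seconds so clients never retry early
    headers = {"Retry-After": str(int(result.retry_after_seconds) + 1)}

    if max_capacity is not None:
        headers["X-RateLimit-Limit"] = str(max_capacity)

    # tokens_remaining is only meaningful for TBRA; GCRA yields None,
    # in which case the header is omitted rather than reported as 0
    if result.tokens_remaining is not None:
        headers["X-RateLimit-Remaining"] = str(int(result.tokens_remaining))

    return JSONResponse(
        status_code=429,
        content={"detail": "rate_limit_exceeded"},
        headers=headers,
    )
```

Keeping this logic in one helper ensures every 429 — whatever policy produced it — carries consistent headers.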
+ +--- + +## Testing + +### Unit Tests + +- Algorithm math (refill, cap, consume) +- Retry-after correctness +- Edge cases: zero refill, capacity=0 +- Key building (str and dict) + +### Integration Tests + +- Concurrent requests from multiple workers +- Atomicity under contention +- NOSCRIPT recovery after Redis restart + +### Load Tests + +- High QPS, verify Redis latency +- Multiple policies per request +- Hot key behavior diff --git a/docs/designs/api-rate-limiting/throttling.middleware.specs.md b/docs/designs/api-rate-limiting/throttling.middleware.specs.md new file mode 100644 index 0000000000..0ea9f16d1c --- /dev/null +++ b/docs/designs/api-rate-limiting/throttling.middleware.specs.md @@ -0,0 +1,342 @@ +# Throttling Middleware Design + +## Overview + +Add throttling middleware right after auth middleware to enforce rate limits based on: +- **Principal**: organization_id (authenticated) or IP (unauthenticated) +- **Plan**: organization's subscription tier +- **Endpoint group**: which group the request belongs to + +## Request Flow + +``` +Request + │ + ▼ +┌─────────────────────┐ +│ Auth Middleware │ ← Sets request.state.organization_id +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Throttle Middleware │ ← NEW: Enforces rate limits +└─────────────────────┘ + │ + ▼ +┌─────────────────────┐ +│ Handler │ +└─────────────────────┘ +``` + +--- + +## Entitlement Expansion + +### Current Entitlement Structure + +```python +class Tracker(str, Enum): + FLAGS = "flags" # Boolean feature flags + COUNTERS = "counters" # Usage counters (monthly) + GAUGES = "gauges" # Resource gauges (hard limits) +``` + +### New Tracker: THROTTLES + +```python +class Tracker(str, Enum): + FLAGS = "flags" + COUNTERS = "counters" + GAUGES = "gauges" + THROTTLES = "throttles" # NEW: Rate limit policies +``` + +### Throttle Type Definition + +```python +class Bucket(BaseModel): + capacity: Optional[int] = None # max tokens in the bucket + rate: Optional[int] = None # tokens added per minute + algorithm: Optional[str] = None + + +class Mode(str, Enum): + INCLUDE = "include" + EXCLUDE = "exclude" + + +class Throttle(BaseModel): + bucket: Bucket + mode: Mode + categories: list[Category] | None = None + endpoints: list[tuple[Method, str]] | None = None +``` + +### Throttle Keys + +Throttle keys are derived from the throttle definition: +- Categories → `cats:{comma-separated-category-values}` +- Endpoints → `eps:{comma-separated-method:path}` +- Fallback → `all` + +--- + +## Entitlements by Plan + +### CLOUD_V0_HOBBY (Free) + +```python +Tracker.THROTTLES: [ + Throttle( + categories=[Category.STANDARD], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=120, rate=120), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=1200, rate=1200), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=120, rate=1), # Burst of 120, then 1/min + ), +] +``` + +### CLOUD_V0_PRO + +```python +Tracker.THROTTLES: [ + Throttle( + categories=[Category.STANDARD], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=360, rate=360), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=3600, rate=3600), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + 
bucket=Bucket(capacity=180, rate=1), # Burst of 180, then 1/min + ), +] +``` + +### CLOUD_V0_BUSINESS / ENTERPRISE + +```python +Tracker.THROTTLES: [ + Throttle( + categories=[Category.STANDARD], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=3600, rate=3600), + ), + Throttle( + categories=[ + Category.CORE_FAST, + Category.TRACING_FAST, + Category.SERVICES_FAST, + ], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=36000, rate=36000), + ), + Throttle( + categories=[ + Category.CORE_SLOW, + Category.TRACING_SLOW, + Category.SERVICES_SLOW, + ], + mode=Mode.INCLUDE, + bucket=Bucket(capacity=1800, rate=1), # Burst of 1800, then 1/min + ), +] +``` + +**Note on SLOW categories**: These endpoints (e.g., analytics queries, span queries) are expensive operations. The rate limit allows an initial burst but then restricts to 1 request per minute to prevent resource exhaustion. + +--- + +## Endpoint Category Registry + +```python +ENDPOINTS: dict[Category, list[tuple[Method, str]]] = { + Category.CORE_FAST: [ + (Method.POST, "*/retrieve"), + ], + Category.TRACING_FAST: [ + (Method.POST, "/otlp/v1/traces"), + ], + Category.TRACING_SLOW: [ + (Method.POST, "/tracing/*/query"), + (Method.POST, "/tracing/spans/analytics"), + ], + Category.SERVICES_FAST: [ + (Method.ANY, "/permissions/verify"), + ], + Category.STANDARD: [], +} +``` + +### Category Resolution + +```python +def resolve_categories(method: str, path: str) -> set[Category]: + categories = set() + for category, endpoints in ENDPOINTS.items(): + for endpoint_method, endpoint_path in endpoints: + if _matches_endpoint(method, path, endpoint_method, endpoint_path): + categories.add(category) + break + + if not categories: + categories.add(Category.STANDARD) + + return categories +``` + +--- + +## Middleware Implementation + +### Throttle Middleware (after auth) + +```python +async def throttling_middleware(request: Request, call_next): + org_id = getattr(request.state, "organization_id", None) + if not org_id: + return await call_next(request) + + # Plan resolution (cached) + plan = await get_cached_plan(org_id) + throttles = ENTITLEMENTS[plan][Tracker.THROTTLES] + + # Resolve categories for this endpoint + categories = resolve_categories(request.method.lower(), request.url.path) + + # Build checks using ThrottleParamsResolver + checks = [] + for throttle in throttles: + if throttle_applies(throttle, categories, request): + resolver = build_params_resolver(org_id, throttle) + checks.append(await resolver(request)) + + # Execute throttle checks + results = await check_throttles(checks) + for result in results: + if not result.allow: + return make_429(result) + + return await call_next(request) +``` + +### Plan Resolution (Cached) + +Plan is cached per organization using `entitlements:subscription`: + +```python +subscription_data = await get_cache( + namespace="entitlements:subscription", + key={"organization_id": org_id}, +) + +if subscription_data is None: + subscription = await subscriptions_service.read(organization_id=org_id) + subscription_data = {"plan": subscription.plan.value, "anchor": subscription.anchor} + await set_cache( + namespace="entitlements:subscription", + key={"organization_id": org_id}, + value=subscription_data, + ) +``` + +### 429 Response + +```python +return JSONResponse( + status_code=429, + content={"detail": "rate_limit_exceeded"}, + headers={"Retry-After": str(int(result.retry_after_seconds) + 1)}, +) +``` + +--- + +## Bucket Key Format + +``` +throttle:organization:{org_id}:plan:{plan}:policy:{slug} +``` + +Examples: +- 
`throttle:organization:org_abc123:plan:cloud_v0_pro:policy:cats:core_fast,services_fast,tracing_fast`
- `throttle:organization:org_abc123:plan:cloud_v0_hobby:policy:cats:standard`

Using a dict key with `check_throttle`:
```python
{"organization": org_id, "plan": plan, "policy": "cats:standard"}
# → "throttle:organization:org_abc123:plan:cloud_v0_hobby:policy:cats:standard"
```

### Policy Slug

The policy slug serves as a unique identifier for logging and metrics. It is derived from the throttle definition:
- Categories → `cats:{comma-separated-sorted-category-values}`
- Endpoints → `eps:{comma-separated-sorted-method:path}`
- Fallback (no categories or endpoints) → `all`

---

## Registration

```python
# Note: with app.middleware("http"), the middleware registered last runs
# first on the request path, so throttling is registered before auth in
# code in order to execute right after auth at runtime.
app.middleware("http")(throttling_middleware)
app.middleware("http")(authentication_middleware)
```

---

## File Structure

```
agenta/api/
├── oss/src/
│   └── utils/
│       └── throttling.py          # check_throttle / check_throttles
└── ee/src/
    └── core/
        ├── entitlements/
        │   └── types.py           # Tracker.THROTTLES definitions
        └── throttle/
            └── middleware.py      # throttling_middleware
```

---

## Summary

| Component | Location | Purpose |
|-----------|----------|---------|
| `Throttle` type | `ee/src/core/entitlements/types.py` | Define throttle policy structure |
| `ENDPOINTS` | `ee/src/core/entitlements/types.py` | Map endpoints to categories |
| `throttling_middleware` | `ee/src/core/throttle/middleware.py` | Enforce rate limits |
| `Tracker.THROTTLES` | `ee/src/core/entitlements/types.py` | Throttles per plan |

diff --git a/docs/designs/api-rate-limiting/throttling.policies.specs.md b/docs/designs/api-rate-limiting/throttling.policies.specs.md
new file mode 100644
index 0000000000..9eb428b295
--- /dev/null
+++ b/docs/designs/api-rate-limiting/throttling.policies.specs.md
@@ -0,0 +1,201 @@

# Policies (PAR(C) Model)

## Overview

A **policy** follows the PAR(C) authorization model:

| Component | Rate Limiting Mapping | Description |
|-----------|----------------------|-------------|
| **P**rincipal | Identifier type | `organization_id` or `ip` |
| **A**ction | Endpoint scope | Which endpoints: all, include groups, exclude groups |
| **R**esource | (implicit) | The API itself |
| **C**ondition | Plan | Subscription tier that activates this policy |
| **Output** | Bucket key + params | Key format + algorithm parameters |

---

## Action: Endpoint Scope Modes

Each policy specifies an **action** that defines which endpoints it applies to.

### Mode 1: All Endpoints

```yaml
scope:
  mode: all
```

Applies to every request for the principal.

### Mode 2: Exclude (All Except)

```yaml
scope:
  mode: exclude
  groups: [exports, auth]        # exclude these groups
  endpoints: [POST /v1/reset]    # exclude specific endpoints
```

Applies to all endpoints EXCEPT those in the specified groups or explicit endpoints.

### Mode 3: Include (Only These)

```yaml
scope:
  mode: include
  groups: [llm]                  # only these groups
  endpoints: [POST /v1/chat]     # only these specific endpoints
```

Applies ONLY to endpoints in the specified groups or explicit endpoints.
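
All three modes reduce to one small predicate. A minimal sketch, assuming the request's groups and endpoint identifier have already been resolved; `Scope` and `scope_applies` are illustrative names, not part of the codebase:

```python
from dataclasses import dataclass, field
from enum import Enum


class ScopeMode(str, Enum):
    ALL = "all"
    INCLUDE = "include"
    EXCLUDE = "exclude"


@dataclass
class Scope:
    mode: ScopeMode
    groups: set[str] = field(default_factory=set)
    endpoints: set[str] = field(default_factory=set)


def scope_applies(scope: Scope, request_groups: set[str], endpoint_id: str) -> bool:
    """Decide whether a policy's action scope covers this request."""
    if scope.mode == ScopeMode.ALL:
        return True

    # A request is "listed" if it belongs to one of the scope's groups
    # or matches one of its explicit endpoints (e.g. "POST /v1/chat")
    listed = bool(scope.groups & request_groups) or endpoint_id in scope.endpoints

    if scope.mode == ScopeMode.INCLUDE:
        return listed
    return not listed  # EXCLUDE: everything except the listed ones
```

With this predicate, the exclude example above covers any request whose resolved groups avoid `exports`, while the include examples match only their listed groups or endpoints.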
+ +--- + +## Output: Bucket Key and Parameters + +Each policy outputs: +- **Bucket key**: Used to identify the Redis bucket +- **Bucket parameters**: `max_capacity`, `refill_rate` + +### Key Format + +``` +throttle:{key-components} +``` + +Key components are built from context: +- Simple: `throttle:global` +- Single dimension: `throttle:org:abc123` +- Multiple dimensions: `throttle:group:llm:org:abc123` + +--- + +## Policy Schema + +```yaml +policy: + slug: string # Unique identifier for logging/metrics + principal_type: org | ip # Who is being limited + condition: + plan: string # Which plan this applies to (* for all) + action: + mode: all | include | exclude + groups: [string] # Group names from registry (optional) + endpoints: [string] # Specific endpoints (optional) + output: + max_capacity: integer # Burst size + refill_rate: integer # Tokens per minute +``` + +--- + +## Example Policies + +### Global Limit for Free Plan + +```yaml +policy: + slug: org-global-free + principal_type: org + condition: + plan: free + action: + mode: all + output: + max_capacity: 100 + refill_rate: 60 +``` + +Key: `throttle:org:{organization_id}` + +### LLM-Specific Limit for Pro Plan + +```yaml +policy: + slug: org-llm-pro + principal_type: org + condition: + plan: pro + action: + mode: include + groups: [llm] + output: + max_capacity: 500 + refill_rate: 300 +``` + +Key: `throttle:group:llm:org:{organization_id}` + +### All Except Exports for Enterprise + +```yaml +policy: + slug: org-non-export-enterprise + principal_type: org + condition: + plan: enterprise + action: + mode: exclude + groups: [exports] + output: + max_capacity: 10000 + refill_rate: 5000 +``` + +Key: `throttle:org:{organization_id}` + +### IP-Based Auth Protection + +```yaml +policy: + slug: ip-auth-default + principal_type: ip + condition: + plan: "*" # applies regardless of plan + action: + mode: include + groups: [auth] + output: + max_capacity: 10 + refill_rate: 5 +``` + +Key: `throttle:group:auth:ip:{ip_address}` + +--- + +## Policy Selection + +For a given request, find all policies where: + +1. `principal_type` matches the request's identifier type +2. `condition.plan` matches the resolved plan (or is wildcard `*`) +3. The endpoint is within the policy's `action` scope + +### Priority (Most Specific First) + +1. Endpoint-specific (include with specific endpoints) +2. Group-specific (include with groups) +3. Exclude-based (all except) +4. 
Global (all) + +### Multiple Policies + +When multiple policies match: +- All are evaluated +- If any denies → request denied +- The "limiting policy" is recorded for headers/logging + +--- + +## Policy Resolution Inputs + +For a given request, policy selection depends on: + +| Input | Source | +|-------|--------| +| `principal_type` | Authentication state | +| `principal_value` | organization_id or IP | +| `plan` | Organization configuration | +| `endpoint_groups` | Route registry | +| `endpoint_id` | Request method + path | diff --git a/hosting/docker-compose/ee/docker-compose.dev.yml b/hosting/docker-compose/ee/docker-compose.dev.yml index 9382983c04..3003fd4dbd 100644 --- a/hosting/docker-compose/ee/docker-compose.dev.yml +++ b/hosting/docker-compose/ee/docker-compose.dev.yml @@ -358,6 +358,8 @@ services: # === NETWORK ============================================== # networks: - agenta-network + # ports: + # - "6379:6379" # === LIFECYCLE ============================================ # restart: always healthcheck: @@ -385,6 +387,8 @@ services: # === NETWORK ============================================== # networks: - agenta-network + # ports: + # - "6381:6381" # === LIFECYCLE ============================================ # restart: always healthcheck: diff --git a/sdk/agenta/sdk/middleware/auth.py b/sdk/agenta/sdk/middleware/auth.py index efb2218246..58779ee7fc 100644 --- a/sdk/agenta/sdk/middleware/auth.py +++ b/sdk/agenta/sdk/middleware/auth.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable, Optional, Dict from os import getenv from json import dumps @@ -32,10 +32,12 @@ def __init__( self, status_code: int = 401, detail: str = "Unauthorized", + headers: Optional[Dict[str, str]] = None, ) -> None: super().__init__( status_code=status_code, content={"detail": detail}, + headers=headers, ) @@ -44,11 +46,13 @@ def __init__( self, status_code: int = 401, content: str = "Unauthorized", + headers: Optional[Dict[str, str]] = None, ) -> None: super().__init__() self.status_code = status_code self.content = content + self.headers = headers class AuthHTTPMiddleware(BaseHTTPMiddleware): @@ -78,6 +82,7 @@ async def dispatch(self, request: Request, call_next: Callable): return DenyResponse( status_code=deny.status_code, detail=deny.content, + headers=deny.headers, ) except: # pylint: disable=bare-except @@ -188,6 +193,25 @@ async def _get_credentials(self, request: Request) -> Optional[str]: status_code=403, content="Permission denied. Please check your permissions or contact your administrator.", ) + elif response.status_code == 429: + headers = { + key: value + for key, value in { + "Retry-After": response.headers.get("retry-after"), + "X-RateLimit-Limit": response.headers.get( + "x-ratelimit-limit" + ), + "X-RateLimit-Remaining": response.headers.get( + "x-ratelimit-remaining" + ), + }.items() + if value is not None + } + raise DenyException( + status_code=429, + content="API Rate limit exceeded. 
Please try again later or upgrade your plan.", + headers=headers or None, + ) elif response.status_code != 200: # log.debug( # f"Agenta returned {response.status_code} - Unexpected status code" diff --git a/sdk/agenta/sdk/middleware/vault.py b/sdk/agenta/sdk/middleware/vault.py index 929815c1bc..144a878cf8 100644 --- a/sdk/agenta/sdk/middleware/vault.py +++ b/sdk/agenta/sdk/middleware/vault.py @@ -4,12 +4,13 @@ import httpx from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from agenta.sdk.utils.logging import get_module_logger from agenta.sdk.utils.constants import TRUTHY from agenta.sdk.utils.cache import TTLLRUCache -from agenta.sdk.utils.exceptions import suppress, display_exception +from agenta.sdk.utils.exceptions import display_exception from agenta.client.backend.types import SecretDto as SecretDTO from agenta.client.backend.types import ( StandardProviderDto as StandardProviderDTO, @@ -58,11 +59,13 @@ def __init__( self, status_code: int = 403, content: str = "Forbidden", + headers: Optional[Dict[str, str]] = None, ) -> None: super().__init__() self.status_code = status_code self.content = content + self.headers = headers class VaultMiddleware(BaseHTTPMiddleware): @@ -81,7 +84,7 @@ async def dispatch( ): request.state.vault = {} - with suppress(): + try: secrets, vault_secrets, local_secrets = await self._get_secrets(request) request.state.vault = { @@ -89,6 +92,14 @@ async def dispatch( "vault_secrets": vault_secrets, "local_secrets": local_secrets, } + except DenyException as exc: + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.content}, + headers=exc.headers, + ) + except Exception: # pylint: disable=bare-except + display_exception("Vault: Secrets Exception") return await call_next(request) @@ -144,6 +155,8 @@ async def _get_secrets(self, request: Request) -> tuple[list, list, list]: local_secrets.append(secret.model_dump()) except DenyException as e: # pylint: disable=bare-except + if e.status_code == 429: + raise e log.warning(f"Agenta [secrets] {e.status_code}: {e.content}") allow_secrets = False except Exception: # pylint: disable=bare-except @@ -158,6 +171,26 @@ async def _get_secrets(self, request: Request) -> tuple[list, list, list]: headers=headers, ) + if response.status_code == 429: + headers = { + key: value + for key, value in { + "Retry-After": response.headers.get("retry-after"), + "X-RateLimit-Limit": response.headers.get( + "x-ratelimit-limit" + ), + "X-RateLimit-Remaining": response.headers.get( + "x-ratelimit-remaining" + ), + }.items() + if value is not None + } + raise DenyException( + status_code=429, + content="API Rate limit exceeded. Please try again later or upgrade your plan.", + headers=headers or None, + ) + if response.status_code != 200: vault_secrets = [] @@ -283,6 +316,25 @@ async def _allow_local_secrets(self, credentials): status_code=403, content="Out of credits. Please set your LLM provider API keys or contact support.", ) + elif response.status_code == 429: + headers = { + key: value + for key, value in { + "Retry-After": response.headers.get("retry-after"), + "X-RateLimit-Limit": response.headers.get( + "x-ratelimit-limit" + ), + "X-RateLimit-Remaining": response.headers.get( + "x-ratelimit-remaining" + ), + }.items() + if value is not None + } + raise DenyException( + status_code=429, + content="API Rate limit exceeded. 
Please try again later or upgrade your plan.",
+                headers=headers or None,
+            )
         elif response.status_code != 200:
             # log.debug(
             #     f"Agenta returned {response.status_code} - Unexpected status code"
@@ -324,7 +376,8 @@ async def _allow_local_secrets(self, credentials):
             return
         except DenyException as deny:
-            _cache.put(_hash, deny)
+            if deny.status_code != 429:
+                _cache.put(_hash, deny)
             raise deny
         except Exception as exc:  # pylint: disable=bare-except
diff --git a/sdk/agenta/sdk/middlewares/routing/auth.py b/sdk/agenta/sdk/middlewares/routing/auth.py
index f176a118fb..acb2d98e3b 100644
--- a/sdk/agenta/sdk/middlewares/routing/auth.py
+++ b/sdk/agenta/sdk/middlewares/routing/auth.py
@@ -198,6 +198,11 @@ async def _get_credentials(self, request: Request) -> Optional[str]:
                 status_code=403,
                 content="Permission denied. Please check your permissions or contact your administrator.",
             )
+        elif response.status_code == 429:
+            raise DenyException(
+                status_code=429,
+                content="API Rate limit exceeded. Please try again later or upgrade your plan.",
+            )
         elif response.status_code != 200:
             # log.debug(
             #     f"Agenta returned {response.status_code} - Unexpected status code"
diff --git a/web/oss/src/components/Playground/Components/PlaygroundGenerations/assets/GenerationCompletionRow/ErrorPanel.tsx b/web/oss/src/components/Playground/Components/PlaygroundGenerations/assets/GenerationCompletionRow/ErrorPanel.tsx
index b6e676064c..b0bd6ee429 100644
--- a/web/oss/src/components/Playground/Components/PlaygroundGenerations/assets/GenerationCompletionRow/ErrorPanel.tsx
+++ b/web/oss/src/components/Playground/Components/PlaygroundGenerations/assets/GenerationCompletionRow/ErrorPanel.tsx
@@ -5,8 +5,26 @@ import SharedEditor from "../../../SharedEditor"
 const GenerationResultUtils = dynamic(() => import("../GenerationResultUtils"), {ssr: false})

 export default function ErrorPanel({result}: {result: any}) {
-    const errorText =
+    let errorText =
         typeof result?.error === "string" ? result.error : String(result?.error ?? "Error")
+
+    if (
+        errorText === "An unknown error occurred" ||
+        errorText === "Unknown error" ||
+        errorText === "Error"
+    ) {
+        const detail =
+            typeof result?.metadata?.rawError?.detail === "string"
+                ? result.metadata.rawError.detail
+                : undefined
+        if (detail) {
+            errorText = detail
+        }
+        const retryAfter = result?.metadata?.retryAfter
+        if (retryAfter) {
+            errorText = `${errorText} Retry after ${retryAfter}s.`
+        }
+    }

     return (