Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,10 @@ SENTRY_DSN=
# Configure these with your own Docker registry images
DOCKER_IMAGE_BACKEND=backend
DOCKER_IMAGE_FRONTEND=frontend

# Redis config
REDIS_URL=redis://redis:6379/0

# Rate Limiting config
RATE_LIMITER_STRATEGY=sliding_window
RATE_LIMIT_FAIL_OPEN=false
34 changes: 34 additions & 0 deletions backend/app/alembic/rate_limiting_algorithms/sliding_window.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
-- KEYS[1] = key
-- ARGV[1] = now_ms
-- ARGV[2] = window_ms
-- ARGV[3] = limit
-- ARGV[4] = member

local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local member = ARGV[4]

local min_score = now - window

-- remove old entries
redis.call('ZREMRANGEBYSCORE', key, 0, min_score)

-- add current
redis.call('ZADD', key, now, member)

-- count
local cnt = redis.call('ZCARD', key)

-- expire same as window
redis.call('PEXPIRE', key, window)

-- fetch oldest
local earliest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
local oldest_ts = 0
if earliest ~= false and earliest ~= nil and #earliest >= 2 then
oldest_ts = earliest[2]
end

return {cnt, oldest_ts}
9 changes: 8 additions & 1 deletion backend/app/api/routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
get_current_active_superuser,
)
from app.core.config import settings
from app.core.rate_limiter.key_strategy.key_strategy_enum import KeyStrategyName
from app.core.rate_limiter.rate_limiter import RateLimiter
from app.core.security import get_password_hash, verify_password
from app.models import (
Item,
Expand All @@ -31,7 +33,12 @@

@router.get(
"/",
dependencies=[Depends(get_current_active_superuser)],
dependencies=[
Depends(
RateLimiter(limit=10, window_seconds=60, key_policy=KeyStrategyName.IP)
),
Depends(get_current_active_superuser),
],
response_model=UsersPublic,
)
def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any:
Expand Down
17 changes: 17 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,30 @@ class Settings(BaseSettings):
list[AnyUrl] | str, BeforeValidator(parse_cors)
] = []

# Correct Redis default inside Docker Compose
REDIS_URL: str = ""
RATE_LIMITER_STRATEGY: Literal["none", "sliding_window"] = "none"
RATE_LIMIT_FAIL_OPEN: bool = True

@computed_field # type: ignore[prop-decorator]
@property
def all_cors_origins(self) -> list[str]:
return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + [
self.FRONTEND_HOST
]

@computed_field # type: ignore[prop-decorator]
@property
def rate_limit_enabled(self) -> bool:
"""
Returns True if rate limiting should be enabled based on strategy.
Mirrors the style of all_cors_origins.
"""
strategy = (self.RATE_LIMITER_STRATEGY or "").strip().lower()
redis_url = (self.REDIS_URL or "").strip()

return strategy not in ("", "none") and bool(redis_url)

PROJECT_NAME: str
SENTRY_DSN: HttpUrl | None = None
POSTGRES_SERVER: str
Expand Down
Empty file.
Empty file.
12 changes: 12 additions & 0 deletions backend/app/core/rate_limiter/key_strategy/header_key_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from starlette.requests import Request

from app.core.rate_limiter.key_strategy.key_strategy import KeyStrategy


class HeaderKeyStrategy(KeyStrategy):
def __init__(self, header_name: str = "X-Client-ID"):
self.header_name = header_name

def get_key(self, request: Request, route_path: str) -> str:
value = request.headers.get(self.header_name, "unknown")
return f"header:{self.header_name}:{value}:{route_path}"
11 changes: 11 additions & 0 deletions backend/app/core/rate_limiter/key_strategy/ip_key_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from starlette.requests import Request

from app.core.rate_limiter.key_strategy.key_strategy import KeyStrategy


class IPKeyStrategy(KeyStrategy):
"""Generate rate limit key based on client IP address."""

def get_key(self, request: Request, route_path: str) -> str:
client_ip = request.client.host if request.client else "unknown"
return f"ip:{client_ip}:{route_path}"
12 changes: 12 additions & 0 deletions backend/app/core/rate_limiter/key_strategy/key_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod

from starlette.requests import Request


class KeyStrategy(ABC):
"""Base interface for rate-limit key generation."""

@abstractmethod
def get_key(self, request: Request, route_path: str) -> str:
"""Return unique identifier string (e.g., 'ip:127.0.0.1')"""
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class KeyStrategyName(str, Enum):
IP = "ip"
HEADER = "header"
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from app.core.rate_limiter.key_strategy.header_key_strategy import HeaderKeyStrategy
from app.core.rate_limiter.key_strategy.ip_key_strategy import IPKeyStrategy
from app.core.rate_limiter.key_strategy.key_strategy import KeyStrategy
from app.core.rate_limiter.key_strategy.key_strategy_enum import KeyStrategyName


def get_key_strategy(
name: KeyStrategyName, header_name: str | None = None
) -> KeyStrategy:
if name == KeyStrategyName.IP:
return IPKeyStrategy()

if name == KeyStrategyName.HEADER:
return HeaderKeyStrategy(header_name=header_name or "X-Client-ID")

raise ValueError(f"Unsupported key strategy: {name}")
52 changes: 52 additions & 0 deletions backend/app/core/rate_limiter/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging

from fastapi import HTTPException, Request

from app.core.rate_limiter.key_strategy import key_strategy_registry
from app.core.rate_limiter.key_strategy.key_strategy_enum import KeyStrategyName

logger = logging.getLogger(__name__)


class RateLimiter:
def __init__(
self,
limit: int,
window_seconds: int,
key_policy: KeyStrategyName = KeyStrategyName.IP,
):
self.limit = limit
self.window_seconds = window_seconds
self.key_policy = key_policy

async def __call__(self, request: Request) -> None:
rate_limiter = getattr(request.app.state, "rate_limiter", None)

if rate_limiter is None:
return None

# Create Key
key_strategy = key_strategy_registry.get_key_strategy(self.key_policy)
path: str = request.scope.get("path") or ""
key = key_strategy.get_key(request, path)

allowed = True
retry_after = None
try:
allowed, retry_after = await rate_limiter.allow_request(
key, self.limit, self.window_seconds
)
except Exception:
logger.exception("Error invoking rate limiter")
if rate_limiter.get_fail_open():
raise HTTPException(
status_code=503,
detail={"detail": "Rate limiter unavailable"},
)

if not allowed:
raise HTTPException(
status_code=429,
detail=f"Too Many Requests. Retry after {retry_after}s",
headers={"Retry-After": str(retry_after)},
)
Empty file.
18 changes: 18 additions & 0 deletions backend/app/core/rate_limiter/rate_limiting_algorithm/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod


class BaseRateLimiter(ABC):
"""Interface for pluggable rate limiter strategies."""

@abstractmethod
async def allow_request(
self, key: str, limit: int, window_seconds: int, member_id: str | None = None
) -> tuple[bool, int | None]:
"""
Return (allowed: bool, retry_after_seconds: Optional[int]).
If allowed True -> retry_after_seconds is None.
If allowed False -> retry_after_seconds is seconds until next allowed request.
"""
raise NotImplementedError
27 changes: 27 additions & 0 deletions backend/app/core/rate_limiter/rate_limiting_algorithm/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import redis.asyncio as redis

from app.core.rate_limiter.rate_limiting_algorithm.base import BaseRateLimiter
from app.core.rate_limiter.rate_limiting_algorithm.sliding_window import (
SlidingWindowRateLimiter,
)


def get_rate_limiter(
strategy: str | None, redis_url: str | None, fail_open: bool | None
) -> BaseRateLimiter | None:
"""
Factory: returns an instance of BaseRateLimiter or None (if disabled).
"""
if not strategy or strategy.lower() in ("none", "null", ""):
return None

if not redis_url:
return None

rc: redis.Redis = redis.from_url(redis_url, encoding="utf-8", decode_responses=True) # type: ignore[no-untyped-call]
st = strategy.lower()
if st == "sliding_window" or st == "sliding-window":
return SlidingWindowRateLimiter(rc, fail_open or False)

# extendable for other strategies
raise ValueError(f"Unknown rate limiter strategy: {strategy}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import logging
import time
from pathlib import Path

import redis.asyncio as redis

from app.core.rate_limiter.rate_limiting_algorithm.base import BaseRateLimiter

logger = logging.getLogger(__name__)

SCRIPT_PATH = (
Path(__file__).resolve().parents[3]
/ "alembic"
/ "rate_limiting_algorithms"
/ "sliding_window.lua"
)


class SlidingWindowRateLimiter(BaseRateLimiter):
def __init__(self, redis_client: redis.Redis, fail_open: bool):
self.redis = redis_client
self.lua_script = None
self.fail_open = fail_open

async def load_script(self) -> str | None:
if self.lua_script is None:
script_text = SCRIPT_PATH.read_text()
# LOAD script into redis → returns SHA
self.lua_script = await self.redis.script_load(script_text)
return self.lua_script

async def allow_request(
self, key: str, limit: int, window_seconds: int, member_id: str | None = None
) -> tuple[bool, int | None]:
now_ms = int(time.time() * 1000)
window_ms = window_seconds * 1000
member = member_id or f"{now_ms}"

try:
sha = await self.load_script()
if sha is None:
raise Exception
res = await self.redis.evalsha( # type: ignore[misc]
sha, 1, key, now_ms, window_ms, limit, member
)
except Exception:
logger.exception("Redis error; failing open")
return True, None

cnt, oldest_ts = int(res[0]), int(res[1] or 0)

if cnt <= limit:
return True, None

retry_after_ms = (oldest_ts + window_ms) - now_ms
retry_after_s = max(0, retry_after_ms // 1000)

return False, retry_after_s

def get_fail_open(self) -> bool:
return self.fail_open
10 changes: 10 additions & 0 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from app.api.main import api_router
from app.core.config import settings
from app.core.rate_limiter.rate_limiting_algorithm.registry import get_rate_limiter


def custom_generate_unique_id(route: APIRoute) -> str:
Expand All @@ -30,4 +31,13 @@ def custom_generate_unique_id(route: APIRoute) -> str:
allow_headers=["*"],
)

# Set rate Limiting
if settings.rate_limit_enabled:
rate_limiter = get_rate_limiter(
settings.RATE_LIMITER_STRATEGY,
settings.REDIS_URL,
settings.RATE_LIMIT_FAIL_OPEN,
)
app.state.rate_limiter = rate_limiter

app.include_router(api_router, prefix=settings.API_V1_STR)
4 changes: 4 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ dependencies = [
"pydantic-settings<3.0.0,>=2.2.1",
"sentry-sdk[fastapi]<2.0.0,>=1.40.6",
"pyjwt<3.0.0,>=2.8.0",

"redis>=4.6.0",
]

[tool.uv]
Expand All @@ -31,6 +33,8 @@ dev-dependencies = [
"pre-commit<4.0.0,>=3.6.2",
"types-passlib<2.0.0.0,>=1.7.7.20240106",
"coverage<8.0.0,>=7.4.3",

"pytest-asyncio",
]

[build-system]
Expand Down
Loading
Loading