|
| 1 | +from starlette.middleware.base import BaseHTTPMiddleware |
| 2 | +from starlette.requests import Request |
| 3 | +from starlette.responses import Response |
| 4 | +from starlette.types import ASGIApp |
| 5 | + |
| 6 | + |
| 7 | +class CORSMiddleware(BaseHTTPMiddleware): |
| 8 | + def __init__(self, app: ASGIApp, allow_origins=None, allow_methods=None, allow_headers=None): |
| 9 | + super().__init__(app) |
| 10 | + self.allow_origins = allow_origins or ["*"] |
| 11 | + self.allow_methods = allow_methods or ["GET", "POST", "OPTIONS"] |
| 12 | + self.allow_headers = allow_headers or ["*"] |
| 13 | + |
| 14 | + async def dispatch(self, request: Request, call_next): |
| 15 | + # Handle OPTIONS method for CORS preflight requests |
| 16 | + if request.method == "OPTIONS": |
| 17 | + response = Response() |
| 18 | + response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins) |
| 19 | + response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods) |
| 20 | + response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers) |
| 21 | + response.headers["Access-Control-Max-Age"] = "3600" # Cache preflight response for 1 hour |
| 22 | + return response |
| 23 | + |
| 24 | + # Process the request normally and then add CORS headers to the response |
| 25 | + response = await call_next(request) |
| 26 | + response.headers["Access-Control-Allow-Origin"] = ",".join(self.allow_origins) |
| 27 | + response.headers["Access-Control-Allow-Methods"] = ",".join(self.allow_methods) |
| 28 | + response.headers["Access-Control-Allow-Headers"] = ",".join(self.allow_headers) |
| 29 | + return response |
0 commit comments