Skip to content

Commit 704ddb8

Browse files
committed
turns out the starlette CORS middleware sucks
1 parent df903d6 commit 704ddb8

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

mystbin/backend/app.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
import pathlib
2222
import os
2323
import sys
24-
from typing import Any, Dict, Optional
24+
from typing import Any, Callable, Coroutine, Dict, Optional
2525

2626
import aiohttp
2727
import aioredis
2828
import sentry_sdk
2929
import ujson
30-
from fastapi import FastAPI, Request
31-
from fastapi.middleware.cors import CORSMiddleware
30+
from fastapi import FastAPI, Request, Response
3231
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
3332
from starlette_prometheus import metrics, PrometheusMiddleware
3433

@@ -66,10 +65,20 @@ def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None, config:
6665

6766

6867
app = MystbinApp()
68+
METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT")
6969

7070

7171
@app.middleware("http")
7272
async def request_stats(request: Request, call_next):
73+
if request.method == "OPTIONS":
74+
raise RuntimeError("blah")
75+
return Response(headers={
76+
"Access-Control-Allowed-Headers": request.headers.get("Access-Control-Request-Headers", ""),
77+
"Access-Control-Allowed-Method": ", ".join(METHODS),
78+
"Access-Control-Allowed-Origin": app.config["site"]["frontend_site"],
79+
"Access-Control-Max-Age": "600",
80+
"Vary": "Origin",
81+
})
7382
request.app.state.request_stats["total"] += 1
7483

7584
if request.url.path != "/admin/stats":
@@ -78,6 +87,22 @@ async def request_stats(request: Request, call_next):
7887
response = await call_next(request)
7988
return response
8089

90+
async def cors_middleware(request: Request, call_next: Callable[[Request], Coroutine[Any, Any, Response]]):
91+
headers={
92+
"Access-Control-Allow-Headers": request.headers.get("Access-Control-Request-Headers", ""),
93+
"Access-Control-Allow-Methods": ", ".join(METHODS),
94+
"Access-Control-Allow-Origin": app.config["site"]["frontend_site"],
95+
"Access-Control-Max-Age": "600",
96+
"Vary": "Origin",
97+
}
98+
99+
if request.method == "OPTIONS":
100+
return Response(headers=headers)
101+
102+
resp = await call_next(request)
103+
resp.headers.update(headers)
104+
return resp
105+
81106

82107
@app.on_event("startup")
83108
async def app_startup():
@@ -98,6 +123,7 @@ async def app_startup():
98123

99124
ratelimits.limiter.startup(app)
100125
app.middleware("http")(ratelimits.limiter.middleware)
126+
app.middleware("http")(cors_middleware)
101127

102128
nocli = pathlib.Path(".nocli")
103129
if nocli.exists():
@@ -114,17 +140,6 @@ async def app_startup():
114140
app.include_router(pastes.router)
115141
app.include_router(user.router)
116142

117-
app.add_middleware(
118-
CORSMiddleware,
119-
allow_origins=[
120-
app.config["site"]["frontend_site"],
121-
app.config["site"]["backend_site"],
122-
],
123-
allow_credentials=True,
124-
allow_methods=["*"],
125-
allow_headers=["*"],
126-
)
127-
128143

129144
try:
130145
sentry_dsn = app.config["sentry"]["dsn"]
@@ -136,5 +151,5 @@ async def app_startup():
136151

137152
app.add_middleware(SentryAsgiMiddleware)
138153

139-
app.add_middleware(PrometheusMiddleware)
140-
app.add_route("/metrics/", metrics)
154+
#app.add_middleware(PrometheusMiddleware)
155+
#app.add_route("/metrics/", metrics)

mystbin/backend/utils/ratelimits.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,19 @@ async def get_key(self, zone: str, request: Request) -> str:
152152

153153
async def _transform_and_log(self, request: Request, response: Response) -> Response:
154154
if isinstance(response, StreamingResponse):
155-
resp_ = Response(status_code=response.status_code, background=response.background, media_type=cast(str, response.media_type))
156-
resp_._headers = response._headers
157-
body = b""
158-
async for chunk in response.body_iterator:
159-
if not isinstance(chunk, bytes):
160-
chunk = chunk.encode(response.charset)
161-
162-
body += chunk
155+
resp_ = Response(status_code=response.status_code, background=response.background, media_type=cast(str, response.media_type))
156+
del response._headers['content-type']
157+
del response._headers['content-length']
158+
resp_._headers = response._headers
159+
body = b""
160+
async for chunk in response.body_iterator:
161+
if not isinstance(chunk, bytes):
162+
chunk = chunk.encode(response.charset)
163163

164-
resp_.body = body
165-
response = resp_
164+
body += chunk
165+
166+
resp_.body = body
167+
response = resp_
166168

167169
await self.app.state.db.put_log(request, response)
168170
return response
@@ -226,7 +228,8 @@ async def middleware(self, request: Request, call_next: _CT) -> Response:
226228

227229
resp = await call_next(request)
228230
resp.headers.update(headers)
229-
return await self._transform_and_log(request, resp)
231+
resp = await self._transform_and_log(request, resp)
232+
return resp
230233

231234

232235
else:

0 commit comments

Comments
 (0)