2121import pathlib
2222import os
2323import sys
24- from typing import Any , Dict , Optional
24+ from typing import Any , Callable , Coroutine , Dict , Optional
2525
2626import aiohttp
2727import aioredis
2828import sentry_sdk
2929import ujson
30- from fastapi import FastAPI , Request
31- from fastapi .middleware .cors import CORSMiddleware
30+ from fastapi import FastAPI , Request , Response
3231from sentry_sdk .integrations .asgi import SentryAsgiMiddleware
3332from starlette_prometheus import metrics , PrometheusMiddleware
3433
@@ -66,10 +65,20 @@ def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None, config:
6665
6766
6867app = MystbinApp ()
68+ METHODS = ("DELETE" , "GET" , "OPTIONS" , "PATCH" , "POST" , "PUT" )
6969
7070
7171@app .middleware ("http" )
7272async def request_stats (request : Request , call_next ):
73+ if request .method == "OPTIONS" :
74+ raise RuntimeError ("blah" )
75+ return Response (headers = {
76+ "Access-Control-Allowed-Headers" : request .headers .get ("Access-Control-Request-Headers" , "" ),
77+ "Access-Control-Allowed-Method" : ", " .join (METHODS ),
78+ "Access-Control-Allowed-Origin" : app .config ["site" ]["frontend_site" ],
79+ "Access-Control-Max-Age" : "600" ,
80+ "Vary" : "Origin" ,
81+ })
7382 request .app .state .request_stats ["total" ] += 1
7483
7584 if request .url .path != "/admin/stats" :
@@ -78,6 +87,22 @@ async def request_stats(request: Request, call_next):
7887 response = await call_next (request )
7988 return response
8089
90+ async def cors_middleware (request : Request , call_next : Callable [[Request ], Coroutine [Any , Any , Response ]]):
91+ headers = {
92+ "Access-Control-Allow-Headers" : request .headers .get ("Access-Control-Request-Headers" , "" ),
93+ "Access-Control-Allow-Methods" : ", " .join (METHODS ),
94+ "Access-Control-Allow-Origin" : app .config ["site" ]["frontend_site" ],
95+ "Access-Control-Max-Age" : "600" ,
96+ "Vary" : "Origin" ,
97+ }
98+
99+ if request .method == "OPTIONS" :
100+ return Response (headers = headers )
101+
102+ resp = await call_next (request )
103+ resp .headers .update (headers )
104+ return resp
105+
81106
82107@app .on_event ("startup" )
83108async def app_startup ():
@@ -98,6 +123,7 @@ async def app_startup():
98123
99124 ratelimits .limiter .startup (app )
100125 app .middleware ("http" )(ratelimits .limiter .middleware )
126+ app .middleware ("http" )(cors_middleware )
101127
102128 nocli = pathlib .Path (".nocli" )
103129 if nocli .exists ():
@@ -114,17 +140,6 @@ async def app_startup():
114140app .include_router (pastes .router )
115141app .include_router (user .router )
116142
117- app .add_middleware (
118- CORSMiddleware ,
119- allow_origins = [
120- app .config ["site" ]["frontend_site" ],
121- app .config ["site" ]["backend_site" ],
122- ],
123- allow_credentials = True ,
124- allow_methods = ["*" ],
125- allow_headers = ["*" ],
126- )
127-
128143
129144try :
130145 sentry_dsn = app .config ["sentry" ]["dsn" ]
@@ -136,5 +151,5 @@ async def app_startup():
136151
137152 app .add_middleware (SentryAsgiMiddleware )
138153
139- app .add_middleware (PrometheusMiddleware )
140- app .add_route ("/metrics/" , metrics )
154+ # app.add_middleware(PrometheusMiddleware)
155+ # app.add_route("/metrics/", metrics)
0 commit comments