@@ -73,6 +73,22 @@ def setup_api_args(parser: argparse.ArgumentParser):
7373 help = "Exclude the specified API from the server" ,
7474 )
7575
76+ import starlette
77+ def configure_cors_middleware (app : FastAPI , allow_origins : list = ["*" ],
78+ allow_credentials : bool = True ,
79+ allow_methods : list = ["*" ],
80+ allow_headers : list = ["*" ],):
81+ from starlette .middleware .cors import CORSMiddleware
82+
83+ cors_options = {
84+ "allow_methods" : allow_methods ,
85+ "allow_headers" : allow_headers ,
86+ "allow_credentials" : allow_credentials ,
87+ "allow_origins" : allow_origins ,
88+ }
89+
90+ app .user_middleware .insert (0 , starlette .middleware .Middleware (CORSMiddleware , ** cors_options ))
91+ app .build_middleware_stack () # rebuild middleware stack on-the-fly
7692
7793def process_api_args (args : argparse .Namespace , app : FastAPI ):
7894 cors_origin = env .get_and_update_env (args , "cors_origin" , "*" , str )
@@ -84,7 +100,7 @@ def process_api_args(args: argparse.Namespace, app: FastAPI):
84100 config .api = api
85101
86102 if cors_origin :
87- api . set_cors ( allow_origins = [cors_origin ])
103+ configure_cors_middleware ( app , allow_origins = [cors_origin ])
88104 logger .info (f"allow CORS origin: { cors_origin } " )
89105
90106 if not no_playground :
0 commit comments