@@ -31,6 +31,7 @@ class OpenApiMiddleware:
3131 openapi_spec_path : str
3232 oidc_config_url : str
3333 private_endpoints : EndpointMethods
34+ public_endpoints : EndpointMethods
3435 default_public : bool
3536 oidc_auth_scheme_name : str = "oidcAuth"
3637
@@ -97,6 +98,23 @@ async def augment_oidc_spec(message: Message):
9798
9899 return await self .app (scope , receive , augment_oidc_spec )
99100
101+ # def augment_spec(self, openapi_spec) -> dict[str, Any]:
102+ # """Augment the OpenAPI spec with auth information."""
103+ # components = openapi_spec.setdefault("components", {})
104+ # securitySchemes = components.setdefault("securitySchemes", {})
105+ # securitySchemes[self.oidc_auth_scheme_name] = {
106+ # "type": "openIdConnect",
107+ # "openIdConnectUrl": self.oidc_config_url,
108+ # }
109+ # for path, method_config in openapi_spec["paths"].items():
110+ # for method, config in method_config.items():
111+ # for private_method in self.private_endpoints.get(path, []):
112+ # if method.casefold() == private_method.casefold():
113+ # config.setdefault("security", []).append(
114+ # {self.oidc_auth_scheme_name: []}
115+ # )
116+ # return openapi_spec
117+
100118 def augment_spec (self , openapi_spec ) -> dict [str , Any ]:
101119 """Augment the OpenAPI spec with auth information."""
102120 components = openapi_spec .setdefault ("components" , {})
@@ -107,9 +125,24 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
107125 }
108126 for path , method_config in openapi_spec ["paths" ].items ():
109127 for method , config in method_config .items ():
110- for private_method in self .private_endpoints .get (path , []):
111- if method .casefold () == private_method .casefold ():
112- config .setdefault ("security" , []).append (
113- {self .oidc_auth_scheme_name : []}
114- )
128+ requires_auth = (
129+ self .path_matches (path , method , self .private_endpoints )
130+ if self .default_public
131+ else not self .path_matches (path , method , self .public_endpoints )
132+ )
133+ if requires_auth :
134+ config .setdefault ("security" , []).append (
135+ {self .oidc_auth_scheme_name : []}
136+ )
115137 return openapi_spec
138+
139+ @staticmethod
140+ def path_matches (path : str , method : str , endpoints : dict [str , list [str ]]) -> bool :
141+ """Check if the given path and method match any of the regex patterns and methods in the endpoints."""
142+ for pattern , endpoint_methods in endpoints .items ():
143+ if not re .match (pattern , path ):
144+ continue
145+ for endpoint_method in endpoint_methods :
146+ if method .casefold () == endpoint_method .casefold ():
147+ return True
148+ return False
0 commit comments