@@ -33,6 +33,29 @@ async def handle(self, context: "AuthorizationContext"):
3333 """Handles this requirement for a given context."""
3434
3535
36+ class RolesRequirement (Requirement ):
37+ """
38+ Requires an identity with certain roles.
39+ Supports defining sufficient roles (any one is enough).
40+ """
41+
42+ __slots__ = ("_roles" ,)
43+
44+ def __init__ (self , roles : Optional [Sequence [str ]] = None ):
45+ self ._roles = list (roles ) if roles else None
46+
47+ def handle (self , context : "AuthorizationContext" ):
48+ identity = context .identity
49+
50+ if not identity :
51+ context .fail ("Missing identity" )
52+ return
53+
54+ if self ._roles :
55+ if any (identity .has_role (name ) for name in self ._roles ):
56+ context .succeed (self )
57+
58+
3659RequirementConfType = Union [Requirement , Type [Requirement ]]
3760
3861
@@ -208,46 +231,78 @@ def with_default_policy(self, policy: Policy) -> "AuthorizationStrategy":
208231 return self
209232
210233 async def authorize (
211- self , policy_name : Optional [str ], identity : Identity , scope : Any = None
234+ self ,
235+ policy_name : Optional [str ],
236+ identity : Identity ,
237+ scope : Any = None ,
238+ roles : Optional [Sequence [str ]] = None ,
212239 ):
213240 if policy_name :
214241 policy = self .get_policy (policy_name )
215242
216243 if not policy :
217244 raise PolicyNotFoundError (policy_name )
218245
219- await self ._handle_with_policy (policy , identity , scope )
246+ await self ._handle_with_policy (policy , identity , scope , roles )
220247 else :
221248 if self .default_policy :
222- await self ._handle_with_policy (self .default_policy , identity , scope )
249+ await self ._handle_with_policy (
250+ self .default_policy , identity , scope , roles
251+ )
252+ return
253+
254+ if roles :
255+ # This code is only executed if the user specified roles without
256+ # specifying an authorization policy.
257+ await self ._handle_with_roles (identity , roles )
223258 return
224259
225260 if not identity :
226261 raise UnauthorizedError ("Missing identity" , [])
227262 if not identity .is_authenticated ():
228263 raise UnauthorizedError ("The resource requires authentication" , [])
229264
230- def _get_requirements (self , policy : Policy , scope : Any ) -> Iterable [Requirement ]:
265+ def _get_requirements (
266+ self , policy : Policy , scope : Any , roles : Optional [Sequence [str ]] = None
267+ ) -> Iterable [Requirement ]:
268+ if roles :
269+ yield RolesRequirement (roles = roles )
231270 yield from self ._get_instances (policy .requirements , scope )
232271
233- async def _handle_with_policy (self , policy : Policy , identity : Identity , scope : Any ):
272+ async def _handle_with_policy (
273+ self ,
274+ policy : Policy ,
275+ identity : Identity ,
276+ scope : Any ,
277+ roles : Optional [Sequence [str ]] = None ,
278+ ):
234279 with AuthorizationContext (
235- identity , list (self ._get_requirements (policy , scope ))
280+ identity , list (self ._get_requirements (policy , scope , roles ))
236281 ) as context :
237- for requirement in context .requirements :
238- if _is_async_handler (type (requirement )): # type: ignore
239- await requirement .handle (context )
240- else :
241- requirement .handle (context ) # type: ignore
242-
243- if not context .has_succeeded :
244- if identity and identity .is_authenticated ():
245- raise ForbiddenError (
246- context .forced_failure , context .pending_requirements
247- )
248- raise UnauthorizedError (
282+ await self ._handle_context (identity , context )
283+
284+ async def _handle_with_roles (
285+ self , identity : Identity , roles : Optional [Sequence [str ]] = None
286+ ):
287+ # This method is to be used only when the user specified roles without a policy
288+ with AuthorizationContext (identity , [RolesRequirement (roles = roles )]) as context :
289+ await self ._handle_context (identity , context )
290+
291+ async def _handle_context (self , identity : Identity , context : AuthorizationContext ):
292+ for requirement in context .requirements :
293+ if _is_async_handler (type (requirement )): # type: ignore
294+ await requirement .handle (context )
295+ else :
296+ requirement .handle (context ) # type: ignore
297+
298+ if not context .has_succeeded :
299+ if identity and identity .is_authenticated ():
300+ raise ForbiddenError (
249301 context .forced_failure , context .pending_requirements
250302 )
303+ raise UnauthorizedError (
304+ context .forced_failure , context .pending_requirements
305+ )
251306
252307 async def _handle_with_identity_getter (
253308 self , policy_name : Optional [str ], * args , ** kwargs
0 commit comments