|
1 | | -from ._base_auth import AuthProviderBase, TokenCredential |
| 1 | +from .abc_token_credential import TokenCredential |
2 | 2 | from ..constants import AUTH_MIDDLEWARE_OPTIONS |
3 | | -from ._middleware import BaseMiddleware |
| 3 | +from .middleware import BaseMiddleware |
4 | 4 | from .options.middleware_control import middleware_control |
5 | 5 |
|
6 | 6 |
|
7 | 7 | class AuthorizationHandler(BaseMiddleware): |
8 | | - def __init__(self, auth_provider: AuthProviderBase): |
| 8 | + def __init__(self, credential: TokenCredential, scopes: [str]): |
9 | 9 | super().__init__() |
10 | | - self.auth_provider = auth_provider |
| 10 | + self.credential = credential |
| 11 | + self.scopes = scopes |
11 | 12 | self.retry_count = 0 |
12 | 13 |
|
13 | 14 | def send(self, request, **kwargs): |
14 | | - # Checks if there are any options for this middleware |
15 | | - options = self._get_middleware_options() |
16 | | - # If there is, get the scopes from the options |
17 | | - if options: |
18 | | - self.auth_provider.scopes = options.scopes |
19 | | - |
20 | | - token = self.auth_provider.get_access_token() |
21 | | - request.headers.update({'Authorization': 'Bearer {}'.format(token)}) |
| 15 | + request.headers.update({'Authorization': 'Bearer {}'.format(self._get_access_token())}) |
22 | 16 | response = super().send(request, **kwargs) |
23 | 17 |
|
24 | | - # Token might have expired just before transmission, retry the request |
| 18 | + # Token might have expired just before transmission, retry the request one more time |
25 | 19 | if response.status_code == 401 and self.retry_count < 2: |
26 | 20 | self.retry_count += 1 |
27 | 21 | return self.send(request, **kwargs) |
28 | | - |
29 | 22 | return response |
30 | 23 |
|
31 | | - def _get_middleware_options(self): |
32 | | - return middleware_control.get(AUTH_MIDDLEWARE_OPTIONS) |
| 24 | + def _get_access_token(self): |
| 25 | + return self.credential.get_token(*self.get_scopes())[0] |
33 | 26 |
|
34 | | - |
35 | | -class TokenCredentialAuthProvider(AuthProviderBase): |
36 | | - def __init__(self, credential: TokenCredential, scopes: [str] = ['.default']): |
37 | | - self.credential = credential |
38 | | - self.scopes = scopes |
39 | | - |
40 | | - def get_access_token(self): |
41 | | - return self.credential.get_token(*self.scopes)[0] |
| 27 | + def get_scopes(self): |
| 28 | + # Checks if there are any options for this middleware |
| 29 | + auth_options_present = middleware_control.get(AUTH_MIDDLEWARE_OPTIONS) |
| 30 | + # If there is, get the scopes from the options |
| 31 | + if auth_options_present: |
| 32 | + return auth_options_present.scopes |
| 33 | + else: |
| 34 | + return self.scopes |
0 commit comments