|
| 1 | +# ------------------------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. See LICENSE.txt in the project root for |
| 4 | +# license information. |
| 5 | +# ------------------------------------------------------------------------- |
| 6 | +import asyncio |
| 7 | +import time |
| 8 | + |
| 9 | +from typing import Any, Awaitable, Optional, Dict, Union |
| 10 | +from azure.core.pipeline.policies import AsyncHTTPPolicy |
| 11 | +from azure.core.credentials import AccessToken |
| 12 | +from azure.core.pipeline import PipelineRequest, PipelineResponse |
| 13 | +from azure.cosmos import http_constants |
| 14 | + |
| 15 | + |
| 16 | +async def await_result(func, *args, **kwargs): |
| 17 | + """If func returns an awaitable, await it.""" |
| 18 | + result = func(*args, **kwargs) |
| 19 | + if hasattr(result, "__await__"): |
| 20 | + # type ignore on await: https://github.com/python/mypy/issues/7587 |
| 21 | + return await result # type: ignore |
| 22 | + return result |
| 23 | + |
| 24 | + |
| 25 | +class _AsyncCosmosBearerTokenCredentialPolicyBase(object): |
| 26 | + """Base class for a Bearer Token Credential Policy. |
| 27 | +
|
| 28 | + :param credential: The credential. |
| 29 | + :type credential: ~azure.core.credentials.TokenCredential |
| 30 | + :param str scopes: Lets you specify the type of access needed. |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument |
| 34 | + # type: (TokenCredential, *str, **Any) -> None |
| 35 | + super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__() |
| 36 | + self._scopes = scopes |
| 37 | + self._credential = credential |
| 38 | + self._token = None # type: Optional[AccessToken] |
| 39 | + self._lock = asyncio.Lock() |
| 40 | + |
| 41 | + @staticmethod |
| 42 | + def _enforce_https(request): |
| 43 | + # type: (PipelineRequest) -> None |
| 44 | + |
| 45 | + # move 'enforce_https' from options to context so it persists |
| 46 | + # across retries but isn't passed to a transport implementation |
| 47 | + option = request.context.options.pop("enforce_https", None) |
| 48 | + |
| 49 | + # True is the default setting; we needn't preserve an explicit opt in to the default behavior |
| 50 | + if option is False: |
| 51 | + request.context["enforce_https"] = option |
| 52 | + |
| 53 | + enforce_https = request.context.get("enforce_https", True) |
| 54 | + if enforce_https and not request.http_request.url.lower().startswith("https"): |
| 55 | + raise ValueError( |
| 56 | + "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." |
| 57 | + ) |
| 58 | + |
| 59 | + @staticmethod |
| 60 | + def _update_headers(headers, token): |
| 61 | + # type: (Dict[str, str], str) -> None |
| 62 | + """Updates the Authorization header with the cosmos signature and bearer token. |
| 63 | + This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works |
| 64 | + to properly sign the authorization header for Cosmos' REST API. For more information: |
| 65 | + https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header |
| 66 | +
|
| 67 | + :param dict headers: The HTTP Request headers |
| 68 | + :param str token: The OAuth token. |
| 69 | + """ |
| 70 | + headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token) |
| 71 | + |
| 72 | + @property |
| 73 | + def _need_new_token(self) -> bool: |
| 74 | + return not self._token or self._token.expires_on - time.time() < 300 |
| 75 | + |
| 76 | + |
| 77 | +class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy): |
| 78 | + """Adds a bearer token Authorization header to requests. |
| 79 | +
|
| 80 | + :param credential: The credential. |
| 81 | + :type credential: ~azure.core.TokenCredential |
| 82 | + :param str scopes: Lets you specify the type of access needed. |
| 83 | + :raises ValueError: If https_enforce does not match with endpoint being used. |
| 84 | + """ |
| 85 | + |
| 86 | + async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method |
| 87 | + """Adds a bearer token Authorization header to request and sends request to next policy. |
| 88 | +
|
| 89 | + :param request: The pipeline request object to be modified. |
| 90 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 91 | + :raises: :class:`~azure.core.exceptions.ServiceRequestError` |
| 92 | + """ |
| 93 | + self._enforce_https(request) # pylint:disable=protected-access |
| 94 | + |
| 95 | + if self._token is None or self._need_new_token: |
| 96 | + async with self._lock: |
| 97 | + # double check because another coroutine may have acquired a token while we waited to acquire the lock |
| 98 | + if self._token is None or self._need_new_token: |
| 99 | + self._token = await self._credential.get_token(*self._scopes) |
| 100 | + self._update_headers(request.http_request.headers, self._token.token) |
| 101 | + |
| 102 | + async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None: |
| 103 | + """Acquire a token from the credential and authorize the request with it. |
| 104 | +
|
| 105 | + Keyword arguments are passed to the credential's get_token method. The token will be cached and used to |
| 106 | + authorize future requests. |
| 107 | +
|
| 108 | + :param ~azure.core.pipeline.PipelineRequest request: the request |
| 109 | + :param str scopes: required scopes of authentication |
| 110 | + """ |
| 111 | + async with self._lock: |
| 112 | + self._token = await self._credential.get_token(*scopes, **kwargs) |
| 113 | + self._update_headers(request.http_request.headers, self._token.token) |
| 114 | + |
| 115 | + async def send(self, request: "PipelineRequest") -> "PipelineResponse": |
| 116 | + """Authorize request with a bearer token and send it to the next policy |
| 117 | +
|
| 118 | + :param request: The pipeline request object |
| 119 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 120 | + """ |
| 121 | + await await_result(self.on_request, request) |
| 122 | + try: |
| 123 | + response = await self.next.send(request) |
| 124 | + await await_result(self.on_response, request, response) |
| 125 | + except Exception: # pylint:disable=broad-except |
| 126 | + handled = await await_result(self.on_exception, request) |
| 127 | + if not handled: |
| 128 | + raise |
| 129 | + else: |
| 130 | + if response.http_response.status_code == 401: |
| 131 | + self._token = None # any cached token is invalid |
| 132 | + if "WWW-Authenticate" in response.http_response.headers: |
| 133 | + request_authorized = await self.on_challenge(request, response) |
| 134 | + if request_authorized: |
| 135 | + try: |
| 136 | + response = await self.next.send(request) |
| 137 | + await await_result(self.on_response, request, response) |
| 138 | + except Exception: # pylint:disable=broad-except |
| 139 | + handled = await await_result(self.on_exception, request) |
| 140 | + if not handled: |
| 141 | + raise |
| 142 | + |
| 143 | + return response |
| 144 | + |
| 145 | + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: |
| 146 | + """Authorize request according to an authentication challenge |
| 147 | +
|
| 148 | + This method is called when the resource provider responds 401 with a WWW-Authenticate header. |
| 149 | +
|
| 150 | + :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge |
| 151 | + :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response |
| 152 | + :returns: a bool indicating whether the policy should send the request |
| 153 | + """ |
| 154 | + # pylint:disable=unused-argument,no-self-use |
| 155 | + return False |
| 156 | + |
| 157 | + def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]: |
| 158 | + """Executed after the request comes back from the next policy. |
| 159 | +
|
| 160 | + :param request: Request to be modified after returning from the policy. |
| 161 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 162 | + :param response: Pipeline response object |
| 163 | + :type response: ~azure.core.pipeline.PipelineResponse |
| 164 | + """ |
| 165 | + |
| 166 | + def on_exception(self, request: PipelineRequest) -> None: |
| 167 | + """Executed when an exception is raised while executing the next policy. |
| 168 | +
|
| 169 | + This method is executed inside the exception handler. |
| 170 | +
|
| 171 | + :param request: The Pipeline request object |
| 172 | + :type request: ~azure.core.pipeline.PipelineRequest |
| 173 | + """ |
| 174 | + # pylint: disable=no-self-use,unused-argument |
| 175 | + return |
0 commit comments