4
4
# license information.
5
5
# -------------------------------------------------------------------------
6
6
import time
7
- from typing import TYPE_CHECKING , Dict , Optional , TypeVar
8
-
7
+ from typing import TYPE_CHECKING , Optional , TypeVar , MutableMapping
8
+ from azure .core .pipeline import PipelineRequest , PipelineResponse
9
+ from azure .core .pipeline .transport import HttpResponse as LegacyHttpResponse , HttpRequest as LegacyHttpRequest
10
+ from azure .core .rest import HttpResponse , HttpRequest
9
11
from . import HTTPPolicy , SansIOHTTPPolicy
10
12
from ...exceptions import ServiceRequestError
11
13
17
19
AzureKeyCredential ,
18
20
AzureSasCredential ,
19
21
)
20
- from azure .core .pipeline import PipelineRequest , PipelineResponse
21
- from azure .core .pipeline .policies ._universal import HTTPRequestType
22
22
23
- HTTPResponseTypeVar = TypeVar ("HTTPResponseTypeVar" )
23
+ HTTPResponseType = TypeVar ("HTTPResponseType" , HttpResponse , LegacyHttpResponse )
24
+ HTTPRequestType = TypeVar ("HTTPRequestType" , HttpRequest , LegacyHttpRequest )
24
25
25
26
26
27
# pylint:disable=too-few-public-methods
@@ -39,7 +40,7 @@ def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> Non
39
40
self ._token : Optional ["AccessToken" ] = None
40
41
41
42
@staticmethod
42
- def _enforce_https (request : " PipelineRequest" ) -> None :
43
+ def _enforce_https (request : PipelineRequest [ HTTPRequestType ] ) -> None :
43
44
# move 'enforce_https' from options to context so it persists
44
45
# across retries but isn't passed to a transport implementation
45
46
option = request .context .options .pop ("enforce_https" , None )
@@ -55,10 +56,10 @@ def _enforce_https(request: "PipelineRequest") -> None:
55
56
)
56
57
57
58
@staticmethod
58
- def _update_headers (headers : Dict [str , str ], token : str ) -> None :
59
+ def _update_headers (headers : MutableMapping [str , str ], token : str ) -> None :
59
60
"""Updates the Authorization header with the bearer token.
60
61
61
- :param dict headers: The HTTP Request headers
62
+ :param MutableMapping[str, str] headers: The HTTP Request headers
62
63
:param str token: The OAuth token.
63
64
"""
64
65
headers ["Authorization" ] = "Bearer {}" .format (token )
@@ -68,7 +69,7 @@ def _need_new_token(self) -> bool:
68
69
return not self ._token or self ._token .expires_on - time .time () < 300
69
70
70
71
71
- class BearerTokenCredentialPolicy (_BearerTokenCredentialPolicyBase , HTTPPolicy ):
72
+ class BearerTokenCredentialPolicy (_BearerTokenCredentialPolicyBase , HTTPPolicy [ HTTPRequestType , HTTPResponseType ] ):
72
73
"""Adds a bearer token Authorization header to requests.
73
74
74
75
:param credential: The credential.
@@ -77,7 +78,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
77
78
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
78
79
"""
79
80
80
- def on_request (self , request : " PipelineRequest" ) -> None :
81
+ def on_request (self , request : PipelineRequest [ HTTPRequestType ] ) -> None :
81
82
"""Called before the policy sends a request.
82
83
83
84
The base implementation authorizes the request with a bearer token.
@@ -90,7 +91,7 @@ def on_request(self, request: "PipelineRequest") -> None:
90
91
self ._token = self ._credential .get_token (* self ._scopes )
91
92
self ._update_headers (request .http_request .headers , self ._token .token )
92
93
93
- def authorize_request (self , request : " PipelineRequest" , * scopes : str , ** kwargs ) -> None :
94
+ def authorize_request (self , request : PipelineRequest [ HTTPRequestType ] , * scopes : str , ** kwargs ) -> None :
94
95
"""Acquire a token from the credential and authorize the request with it.
95
96
96
97
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
@@ -102,7 +103,7 @@ def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs)
102
103
self ._token = self ._credential .get_token (* scopes , ** kwargs )
103
104
self ._update_headers (request .http_request .headers , self ._token .token )
104
105
105
- def send (self , request : " PipelineRequest" ) -> " PipelineResponse" :
106
+ def send (self , request : PipelineRequest [ HTTPRequestType ] ) -> PipelineResponse [ HTTPRequestType , HTTPResponseType ] :
106
107
"""Authorize request with a bearer token and send it to the next policy
107
108
108
109
:param request: The pipeline request object
@@ -136,7 +137,9 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse":
136
137
137
138
return response
138
139
139
- def on_challenge (self , request : "PipelineRequest" , response : "PipelineResponse" ) -> bool :
140
+ def on_challenge (
141
+ self , request : PipelineRequest [HTTPRequestType ], response : PipelineResponse [HTTPRequestType , HTTPResponseType ]
142
+ ) -> bool :
140
143
"""Authorize request according to an authentication challenge
141
144
142
145
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
@@ -149,7 +152,9 @@ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse")
149
152
# pylint:disable=unused-argument
150
153
return False
151
154
152
- def on_response (self , request : "PipelineRequest" , response : "PipelineResponse" ) -> None :
155
+ def on_response (
156
+ self , request : PipelineRequest [HTTPRequestType ], response : PipelineResponse [HTTPRequestType , HTTPResponseType ]
157
+ ) -> None :
153
158
"""Executed after the request comes back from the next policy.
154
159
155
160
:param request: Request to be modified after returning from the policy.
@@ -158,7 +163,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse")
158
163
:type response: ~azure.core.pipeline.PipelineResponse
159
164
"""
160
165
161
- def on_exception (self , request : " PipelineRequest" ) -> None :
166
+ def on_exception (self , request : PipelineRequest [ HTTPRequestType ] ) -> None :
162
167
"""Executed when an exception is raised while executing the next policy.
163
168
164
169
This method is executed inside the exception handler.
@@ -170,7 +175,7 @@ def on_exception(self, request: "PipelineRequest") -> None:
170
175
return
171
176
172
177
173
- class AzureKeyCredentialPolicy (SansIOHTTPPolicy [" HTTPRequestType" , HTTPResponseTypeVar ]):
178
+ class AzureKeyCredentialPolicy (SansIOHTTPPolicy [HTTPRequestType , HTTPResponseType ]):
174
179
"""Adds a key header for the provided credential.
175
180
176
181
:param credential: The credential used to authenticate requests.
@@ -199,11 +204,11 @@ def __init__(
199
204
self ._name = name
200
205
self ._prefix = prefix + " " if prefix else ""
201
206
202
- def on_request (self , request : " PipelineRequest[HTTPRequestType]" ) -> None :
207
+ def on_request (self , request : PipelineRequest [HTTPRequestType ]) -> None :
203
208
request .http_request .headers [self ._name ] = f"{ self ._prefix } { self ._credential .key } "
204
209
205
210
206
- class AzureSasCredentialPolicy (SansIOHTTPPolicy ):
211
+ class AzureSasCredentialPolicy (SansIOHTTPPolicy [ HTTPRequestType , HTTPResponseType ] ):
207
212
"""Adds a shared access signature to query for the provided credential.
208
213
209
214
:param credential: The credential used to authenticate requests.
@@ -217,7 +222,7 @@ def __init__(self, credential: "AzureSasCredential", **kwargs) -> None: # pylin
217
222
raise ValueError ("credential can not be None" )
218
223
self ._credential = credential
219
224
220
- def on_request (self , request ) :
225
+ def on_request (self , request : PipelineRequest [ HTTPRequestType ]) -> None :
221
226
url = request .http_request .url
222
227
query = request .http_request .query
223
228
signature = self ._credential .signature
0 commit comments