Skip to content

Commit 9b04f09

Browse files
xiangyan99lmazuel
andauthored
typing improvement for policies (Azure#31018)
* typing improvement for policies * update * update * update * update * update * update * update * update * update * update * update * Update sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py Co-authored-by: Laurent Mazuel <[email protected]> * address review feedback * update * update * update * update * update * update * update * update * update * update * Update sdk/core/azure-core/azure/core/pipeline/__init__.py * update * update * update * update * Update __init__.py * update * update * update --------- Co-authored-by: Laurent Mazuel <[email protected]>
1 parent 1ddf363 commit 9b04f09

File tree

15 files changed

+166
-117
lines changed

15 files changed

+166
-117
lines changed

sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import time
7-
from typing import TYPE_CHECKING, Dict, Optional, TypeVar
8-
7+
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping
8+
from azure.core.pipeline import PipelineRequest, PipelineResponse
9+
from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
10+
from azure.core.rest import HttpResponse, HttpRequest
911
from . import HTTPPolicy, SansIOHTTPPolicy
1012
from ...exceptions import ServiceRequestError
1113

@@ -17,10 +19,9 @@
1719
AzureKeyCredential,
1820
AzureSasCredential,
1921
)
20-
from azure.core.pipeline import PipelineRequest, PipelineResponse
21-
from azure.core.pipeline.policies._universal import HTTPRequestType
2222

23-
HTTPResponseTypeVar = TypeVar("HTTPResponseTypeVar")
23+
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
24+
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
2425

2526

2627
# pylint:disable=too-few-public-methods
@@ -39,7 +40,7 @@ def __init__(self, credential: "TokenCredential", *scopes: str, **kwargs) -> Non
3940
self._token: Optional["AccessToken"] = None
4041

4142
@staticmethod
42-
def _enforce_https(request: "PipelineRequest") -> None:
43+
def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
4344
# move 'enforce_https' from options to context so it persists
4445
# across retries but isn't passed to a transport implementation
4546
option = request.context.options.pop("enforce_https", None)
@@ -55,10 +56,10 @@ def _enforce_https(request: "PipelineRequest") -> None:
5556
)
5657

5758
@staticmethod
58-
def _update_headers(headers: Dict[str, str], token: str) -> None:
59+
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
5960
"""Updates the Authorization header with the bearer token.
6061
61-
:param dict headers: The HTTP Request headers
62+
:param MutableMapping[str, str] headers: The HTTP Request headers
6263
:param str token: The OAuth token.
6364
"""
6465
headers["Authorization"] = "Bearer {}".format(token)
@@ -68,7 +69,7 @@ def _need_new_token(self) -> bool:
6869
return not self._token or self._token.expires_on - time.time() < 300
6970

7071

71-
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
72+
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
7273
"""Adds a bearer token Authorization header to requests.
7374
7475
:param credential: The credential.
@@ -77,7 +78,7 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy):
7778
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
7879
"""
7980

80-
def on_request(self, request: "PipelineRequest") -> None:
81+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
8182
"""Called before the policy sends a request.
8283
8384
The base implementation authorizes the request with a bearer token.
@@ -90,7 +91,7 @@ def on_request(self, request: "PipelineRequest") -> None:
9091
self._token = self._credential.get_token(*self._scopes)
9192
self._update_headers(request.http_request.headers, self._token.token)
9293

93-
def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs) -> None:
94+
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs) -> None:
9495
"""Acquire a token from the credential and authorize the request with it.
9596
9697
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
@@ -102,7 +103,7 @@ def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs)
102103
self._token = self._credential.get_token(*scopes, **kwargs)
103104
self._update_headers(request.http_request.headers, self._token.token)
104105

105-
def send(self, request: "PipelineRequest") -> "PipelineResponse":
106+
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
106107
"""Authorize request with a bearer token and send it to the next policy
107108
108109
:param request: The pipeline request object
@@ -136,7 +137,9 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse":
136137

137138
return response
138139

139-
def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
140+
def on_challenge(
141+
self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
142+
) -> bool:
140143
"""Authorize request according to an authentication challenge
141144
142145
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
@@ -149,7 +152,9 @@ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse")
149152
# pylint:disable=unused-argument
150153
return False
151154

152-
def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
155+
def on_response(
156+
self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
157+
) -> None:
153158
"""Executed after the request comes back from the next policy.
154159
155160
:param request: Request to be modified after returning from the policy.
@@ -158,7 +163,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse")
158163
:type response: ~azure.core.pipeline.PipelineResponse
159164
"""
160165

161-
def on_exception(self, request: "PipelineRequest") -> None:
166+
def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
162167
"""Executed when an exception is raised while executing the next policy.
163168
164169
This method is executed inside the exception handler.
@@ -170,7 +175,7 @@ def on_exception(self, request: "PipelineRequest") -> None:
170175
return
171176

172177

173-
class AzureKeyCredentialPolicy(SansIOHTTPPolicy["HTTPRequestType", HTTPResponseTypeVar]):
178+
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
174179
"""Adds a key header for the provided credential.
175180
176181
:param credential: The credential used to authenticate requests.
@@ -199,11 +204,11 @@ def __init__(
199204
self._name = name
200205
self._prefix = prefix + " " if prefix else ""
201206

202-
def on_request(self, request: "PipelineRequest[HTTPRequestType]") -> None:
207+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
203208
request.http_request.headers[self._name] = f"{self._prefix}{self._credential.key}"
204209

205210

206-
class AzureSasCredentialPolicy(SansIOHTTPPolicy):
211+
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
207212
"""Adds a shared access signature to query for the provided credential.
208213
209214
:param credential: The credential used to authenticate requests.
@@ -217,7 +222,7 @@ def __init__(self, credential: "AzureSasCredential", **kwargs) -> None: # pylin
217222
raise ValueError("credential can not be None")
218223
self._credential = credential
219224

220-
def on_request(self, request):
225+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
221226
url = request.http_request.url
222227
query = request.http_request.query
223228
signature = self._credential.signature

sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,27 @@
55
# -------------------------------------------------------------------------
66
import asyncio
77
import time
8-
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast
8+
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
99

1010
from azure.core.credentials import AccessToken
11+
from azure.core.pipeline import PipelineRequest, PipelineResponse
1112
from azure.core.pipeline.policies import AsyncHTTPPolicy
1213
from azure.core.pipeline.policies._authentication import (
1314
_BearerTokenCredentialPolicyBase,
1415
)
16+
from azure.core.pipeline.transport import AsyncHttpResponse as LegacyAsyncHttpResponse, HttpRequest as LegacyHttpRequest
17+
from azure.core.rest import AsyncHttpResponse, HttpRequest
1518

1619
from .._tools_async import await_result
1720

1821
if TYPE_CHECKING:
1922
from azure.core.credentials_async import AsyncTokenCredential
20-
from azure.core.pipeline import PipelineRequest, PipelineResponse
2123

24+
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", AsyncHttpResponse, LegacyAsyncHttpResponse)
25+
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
2226

23-
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy):
27+
28+
class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
2429
"""Adds a bearer token Authorization header to requests.
2530
2631
:param credential: The credential.
@@ -36,7 +41,7 @@ def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: A
3641
self._scopes = scopes
3742
self._token: Optional["AccessToken"] = None
3843

39-
async def on_request(self, request: "PipelineRequest") -> None:
44+
async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4045
"""Adds a bearer token Authorization header to request and sends request to next policy.
4146
4247
:param request: The pipeline request object to be modified.
@@ -52,7 +57,7 @@ async def on_request(self, request: "PipelineRequest") -> None:
5257
self._token = await await_result(self._credential.get_token, *self._scopes)
5358
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
5459

55-
async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: Any) -> None:
60+
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
5661
"""Acquire a token from the credential and authorize the request with it.
5762
5863
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
@@ -65,7 +70,9 @@ async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kw
6570
self._token = await await_result(self._credential.get_token, *scopes, **kwargs)
6671
request.http_request.headers["Authorization"] = "Bearer " + cast(AccessToken, self._token).token
6772

68-
async def send(self, request: "PipelineRequest") -> "PipelineResponse":
73+
async def send(
74+
self, request: PipelineRequest[HTTPRequestType]
75+
) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
6976
"""Authorize request with a bearer token and send it to the next policy
7077
7178
:param request: The pipeline request object
@@ -101,7 +108,11 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse":
101108

102109
return response
103110

104-
async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
111+
async def on_challenge(
112+
self,
113+
request: PipelineRequest[HTTPRequestType],
114+
response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
115+
) -> bool:
105116
"""Authorize request according to an authentication challenge
106117
107118
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
@@ -114,7 +125,11 @@ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResp
114125
# pylint:disable=unused-argument
115126
return False
116127

117-
def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> Optional[Awaitable[None]]:
128+
def on_response(
129+
self,
130+
request: PipelineRequest[HTTPRequestType],
131+
response: PipelineResponse[HTTPRequestType, AsyncHTTPResponseType],
132+
) -> Optional[Awaitable[None]]:
118133
"""Executed after the request comes back from the next policy.
119134
120135
:param request: Request to be modified after returning from the policy.
@@ -123,7 +138,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse")
123138
:type response: ~azure.core.pipeline.PipelineResponse
124139
"""
125140

126-
def on_exception(self, request: "PipelineRequest") -> None:
141+
def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
127142
"""Executed when an exception is raised while executing the next policy.
128143
129144
This method is executed inside the exception handler.

sdk/core/azure-core/azure/core/pipeline/policies/_base.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import logging
3030

3131
from typing import (
32-
TYPE_CHECKING,
3332
Generic,
3433
TypeVar,
3534
Union,
@@ -41,29 +40,23 @@
4140

4241
from azure.core.pipeline import PipelineRequest, PipelineResponse
4342

44-
if TYPE_CHECKING:
45-
from azure.core.pipeline.transport import HttpTransport
46-
47-
48-
HTTPResponseTypeVar = TypeVar("HTTPResponseTypeVar")
49-
HTTPRequestTypeVar = TypeVar("HTTPRequestTypeVar")
43+
HTTPResponseType = TypeVar("HTTPResponseType")
44+
HTTPRequestType = TypeVar("HTTPRequestType")
5045

5146
_LOGGER = logging.getLogger(__name__)
5247

5348

54-
class HTTPPolicy(abc.ABC, Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]):
49+
class HTTPPolicy(abc.ABC, Generic[HTTPRequestType, HTTPResponseType]):
5550
"""An HTTP policy ABC.
5651
5752
Use with a synchronous pipeline.
5853
"""
5954

60-
next: "HTTPPolicy[HTTPRequestTypeVar, HTTPResponseTypeVar]"
55+
next: "HTTPPolicy[HTTPRequestType, HTTPResponseType]"
6156
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""
6257

6358
@abc.abstractmethod
64-
def send(
65-
self, request: PipelineRequest[HTTPRequestTypeVar]
66-
) -> PipelineResponse[HTTPRequestTypeVar, HTTPResponseTypeVar]:
59+
def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
6760
"""Abstract send method for a synchronous pipeline. Mutates the request.
6861
6962
Context content is dependent on the HttpTransport.
@@ -75,7 +68,7 @@ def send(
7568
"""
7669

7770

78-
class SansIOHTTPPolicy(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]):
71+
class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
7972
"""Represents a sans I/O policy.
8073
8174
SansIOHTTPPolicy is a base class for policies that only modify or
@@ -87,7 +80,7 @@ class SansIOHTTPPolicy(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]):
8780
but they will then be tied to AsyncPipeline usage.
8881
"""
8982

90-
def on_request(self, request: PipelineRequest[HTTPRequestTypeVar]) -> Union[None, Awaitable[None]]:
83+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> Union[None, Awaitable[None]]:
9184
"""Is executed before sending the request from next policy.
9285
9386
:param request: Request to be modified before sent from next policy.
@@ -96,8 +89,8 @@ def on_request(self, request: PipelineRequest[HTTPRequestTypeVar]) -> Union[None
9689

9790
def on_response(
9891
self,
99-
request: PipelineRequest[HTTPRequestTypeVar],
100-
response: PipelineResponse[HTTPRequestTypeVar, HTTPResponseTypeVar],
92+
request: PipelineRequest[HTTPRequestType],
93+
response: PipelineResponse[HTTPRequestType, HTTPResponseType],
10194
) -> Union[None, Awaitable[None]]:
10295
"""Is executed after the request comes back from the policy.
10396
@@ -109,7 +102,7 @@ def on_response(
109102

110103
def on_exception(
111104
self,
112-
request: PipelineRequest[HTTPRequestTypeVar], # pylint: disable=unused-argument
105+
request: PipelineRequest[HTTPRequestType], # pylint: disable=unused-argument
113106
) -> None:
114107
"""Is executed if an exception is raised while executing the next policy.
115108
@@ -129,7 +122,7 @@ def on_exception(
129122
return
130123

131124

132-
class RequestHistory(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]):
125+
class RequestHistory(Generic[HTTPRequestType, HTTPResponseType]):
133126
"""A container for an attempted request and the applicable response.
134127
135128
This is used to document requests/responses that resulted in redirected/retried requests.
@@ -144,12 +137,12 @@ class RequestHistory(Generic[HTTPRequestTypeVar, HTTPResponseTypeVar]):
144137

145138
def __init__(
146139
self,
147-
http_request: HTTPRequestTypeVar,
148-
http_response: Optional[HTTPResponseTypeVar] = None,
140+
http_request: HTTPRequestType,
141+
http_response: Optional[HTTPResponseType] = None,
149142
error: Optional[Exception] = None,
150143
context: Optional[Dict[str, Any]] = None,
151144
) -> None:
152-
self.http_request = copy.deepcopy(http_request)
153-
self.http_response = http_response
154-
self.error = error
155-
self.context = context
145+
self.http_request: HTTPRequestType = copy.deepcopy(http_request)
146+
self.http_response: Optional[HTTPResponseType] = http_response
147+
self.error: Optional[Exception] = error
148+
self.context: Optional[Dict[str, Any]] = context

sdk/core/azure-core/azure/core/pipeline/policies/_base_async.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,27 @@
2525
# --------------------------------------------------------------------------
2626
import abc
2727

28-
from typing import TYPE_CHECKING, Generic, TypeVar
29-
28+
from typing import Generic, TypeVar
3029
from .. import PipelineRequest, PipelineResponse
3130

32-
if TYPE_CHECKING:
33-
from ..transport._base_async import AsyncHttpTransport
34-
35-
36-
AsyncHTTPResponseTypeVar = TypeVar("AsyncHTTPResponseTypeVar")
37-
HTTPResponseTypeVar = TypeVar("HTTPResponseTypeVar")
38-
HTTPRequestTypeVar = TypeVar("HTTPRequestTypeVar")
31+
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
32+
HTTPResponseType = TypeVar("HTTPResponseType")
33+
HTTPRequestType = TypeVar("HTTPRequestType")
3934

4035

41-
class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]):
36+
class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
4237
"""An async HTTP policy ABC.
4338
4439
Use with an asynchronous pipeline.
4540
"""
4641

47-
next: "AsyncHTTPPolicy[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]"
42+
next: "AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]"
4843
"""Pointer to the next policy or a transport (wrapped as a policy). Will be set at pipeline creation."""
4944

5045
@abc.abstractmethod
5146
async def send(
52-
self, request: PipelineRequest[HTTPRequestTypeVar]
53-
) -> PipelineResponse[HTTPRequestTypeVar, AsyncHTTPResponseTypeVar]:
47+
self, request: PipelineRequest[HTTPRequestType]
48+
) -> PipelineResponse[HTTPRequestType, AsyncHTTPResponseType]:
5449
"""Abstract send method for a asynchronous pipeline. Mutates the request.
5550
5651
Context content is dependent on the HttpTransport.

0 commit comments

Comments
 (0)