Skip to content

Commit c6876ec

Browse files
swathipilpvaneck
andauthored
[corehttp] use typeguard for HTTPPolicy/SansIOHTTPPolicy check not isinstance (#34296)
* sansiohttppolicy inherits from protocol * fix sans io policy tests * fix mypy errors * fix httpresponse/requesttype is unbound in retry + contravariant sansio policy * remove extra import * fix pylint/black * address paul's comments * update sansiohttppolicy on_response/on_request to return None for mypy * add is_http_policy/is_sansio_policy typeguard check * add tests * black * add back elif policy * add typing to is__policy checks + tests * Apply suggestions from code review Co-authored-by: Paul Van Eck <[email protected]> --------- Co-authored-by: Paul Van Eck <[email protected]>
1 parent a4fb935 commit c6876ec

File tree

6 files changed

+145
-12
lines changed

6 files changed

+145
-12
lines changed

sdk/core/corehttp/corehttp/runtime/pipeline/_base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from __future__ import annotations
2727
import logging
2828
from typing import Generic, TypeVar, Union, Any, List, Optional, Iterable, ContextManager
29+
from typing_extensions import TypeGuard
2930

3031
from . import (
3132
PipelineRequest,
@@ -42,6 +43,18 @@
4243
_LOGGER = logging.getLogger(__name__)
4344

4445

46+
def is_http_policy(policy: object) -> TypeGuard[HTTPPolicy]:
47+
if hasattr(policy, "send"):
48+
return True
49+
return False
50+
51+
52+
def is_sansio_http_policy(policy: object) -> TypeGuard[SansIOHTTPPolicy]:
53+
if hasattr(policy, "on_request") and hasattr(policy, "on_response"):
54+
return True
55+
return False
56+
57+
4558
class _SansIOHTTPPolicyRunner(HTTPPolicy[HTTPRequestType, HTTPResponseType]):
4659
"""Sync implementation of the SansIO policy.
4760
@@ -123,10 +136,14 @@ def __init__(
123136
self._transport = transport
124137

125138
for policy in policies or []:
126-
if isinstance(policy, SansIOHTTPPolicy):
139+
if is_http_policy(policy):
140+
self._impl_policies.append(policy)
141+
elif is_sansio_http_policy(policy):
127142
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
128143
elif policy:
129-
self._impl_policies.append(policy)
144+
raise AttributeError(
145+
f"'{type(policy)}' object has no attribute 'send' or both 'on_request' and 'on_response'."
146+
)
130147
for index in range(len(self._impl_policies) - 1):
131148
self._impl_policies[index].next = self._impl_policies[index + 1]
132149
if self._impl_policies:

sdk/core/corehttp/corehttp/runtime/pipeline/_base_async.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,27 @@
2424
#
2525
# --------------------------------------------------------------------------
2626
from __future__ import annotations
27+
import inspect
2728
from types import TracebackType
2829
from typing import Any, Union, Generic, TypeVar, List, Optional, Iterable, Type
29-
from typing_extensions import AsyncContextManager
30+
from typing_extensions import AsyncContextManager, TypeGuard
3031

3132
from . import PipelineRequest, PipelineResponse, PipelineContext
3233
from ..policies import AsyncHTTPPolicy, SansIOHTTPPolicy
34+
from ..pipeline._base import is_sansio_http_policy
3335
from ._tools_async import await_result as _await_result
3436
from ...transport import AsyncHttpTransport
3537

3638
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
3739
HTTPRequestType = TypeVar("HTTPRequestType")
3840

3941

42+
def is_async_http_policy(policy: object) -> TypeGuard[AsyncHTTPPolicy]:
43+
if hasattr(policy, "send") and inspect.iscoroutinefunction(policy.send):
44+
return True
45+
return False
46+
47+
4048
class _SansIOAsyncHTTPPolicyRunner(
4149
AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]
4250
): # pylint: disable=unsubscriptable-object
@@ -127,10 +135,14 @@ def __init__(
127135
self._transport = transport
128136

129137
for policy in policies or []:
130-
if isinstance(policy, SansIOHTTPPolicy):
138+
if is_async_http_policy(policy):
139+
self._impl_policies.append(policy)
140+
elif is_sansio_http_policy(policy):
131141
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
132142
elif policy:
133-
self._impl_policies.append(policy)
143+
raise AttributeError(
144+
f"'{type(policy)}' object has no attribute 'send' or both 'on_request' and 'on_response'."
145+
)
134146
for index in range(len(self._impl_policies) - 1):
135147
self._impl_policies[index].next = self._impl_policies[index + 1]
136148
if self._impl_policies:

sdk/core/corehttp/tests/async_tests/test_authentication_async.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SansIOHTTPPolicy,
1717
)
1818
from corehttp.rest import HttpRequest
19+
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
1920
import pytest
2021

2122
pytestmark = pytest.mark.asyncio
@@ -93,12 +94,15 @@ async def get_token(*_, **__):
9394
get_token_calls += 1
9495
return expected_token
9596

97+
async def send_mock(_):
98+
return Mock(http_response=Mock(status_code=200))
99+
96100
credential = Mock(get_token=get_token)
97101
policies = [
98102
AsyncBearerTokenCredentialPolicy(credential, "scope"),
99-
Mock(send=Mock(return_value=get_completed_future(Mock()))),
103+
Mock(send=send_mock),
100104
]
101-
pipeline = AsyncPipeline(transport=Mock, policies=policies)
105+
pipeline = AsyncPipeline(transport=Mock(), policies=policies)
102106

103107
await pipeline.run(HttpRequest("GET", "https://spam.eggs"))
104108
assert get_token_calls == 1 # policy has no token at first request -> it should call get_token
@@ -111,7 +115,7 @@ async def get_token(*_, **__):
111115
expected_token = expired_token
112116
policies = [
113117
AsyncBearerTokenCredentialPolicy(credential, "scope"),
114-
Mock(send=lambda _: get_completed_future(Mock())),
118+
Mock(send=send_mock),
115119
]
116120
pipeline = AsyncPipeline(transport=Mock(), policies=policies)
117121

@@ -238,6 +242,27 @@ async def fake_send(*args, **kwargs):
238242
policy.on_exception.assert_called_once_with(policy.request)
239243

240244

245+
async def test_azure_core_sans_io_policy():
246+
"""Tests to see that we can use an azure.core SansIOHTTPPolicy with the corehttp Pipeline"""
247+
248+
class TestPolicy(AzureKeyCredentialPolicy):
249+
def __init__(self, *args, **kwargs):
250+
super(TestPolicy, self).__init__(*args, **kwargs)
251+
self.on_exception = Mock(return_value=False)
252+
self.on_request = Mock()
253+
254+
credential = Mock(
255+
get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))), key="key"
256+
)
257+
policy = TestPolicy(credential, "scope")
258+
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))
259+
260+
pipeline = AsyncPipeline(transport=transport, policies=[policy])
261+
await pipeline.run(HttpRequest("GET", "https://localhost"))
262+
263+
policy.on_request.assert_called_once()
264+
265+
241266
def get_completed_future(result=None):
242267
fut = asyncio.Future()
243268
fut.set_result(result)

sdk/core/corehttp/tests/async_tests/test_pipeline_async.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
from typing import cast
7-
from unittest.mock import AsyncMock, PropertyMock
7+
from unittest.mock import AsyncMock, PropertyMock, Mock
88

99
from corehttp.rest import HttpRequest
1010
from corehttp.runtime import AsyncPipelineClient
@@ -49,6 +49,39 @@ async def __aexit__(self, exc_type, exc_value, traceback):
4949
await pipeline.run(req)
5050

5151

52+
def test_invalid_policy_error():
53+
# non-HTTPPolicy/non-SansIOHTTPPolicy should raise an error
54+
class FooPolicy:
55+
pass
56+
57+
# sync send method should raise an error
58+
class SyncSendPolicy:
59+
def send(self, request):
60+
pass
61+
62+
# only on_request should raise an error
63+
class OnlyOnRequestPolicy:
64+
def on_request(self, request):
65+
pass
66+
67+
# only on_response should raise an error
68+
class OnlyOnResponsePolicy:
69+
def on_response(self, request, response):
70+
pass
71+
72+
with pytest.raises(AttributeError):
73+
pipeline = AsyncPipeline(transport=Mock(), policies=[FooPolicy()])
74+
75+
with pytest.raises(AttributeError):
76+
pipeline = AsyncPipeline(transport=Mock(), policies=[SyncSendPolicy()])
77+
78+
with pytest.raises(AttributeError):
79+
pipeline = AsyncPipeline(transport=Mock(), policies=[OnlyOnRequestPolicy()])
80+
81+
with pytest.raises(AttributeError):
82+
pipeline = AsyncPipeline(transport=Mock(), policies=[OnlyOnResponsePolicy()])
83+
84+
5285
@pytest.mark.asyncio
5386
@pytest.mark.parametrize("transport", ASYNC_TRANSPORTS)
5487
async def test_transport_socket_timeout(transport):
@@ -95,7 +128,7 @@ async def test_basic_aiohttp_separate_session(port):
95128
@pytest.mark.asyncio
96129
async def test_retry_without_http_response():
97130
class NaughtyPolicy(AsyncHTTPPolicy):
98-
def send(*args):
131+
async def send(*args):
99132
raise BaseError("boo")
100133

101134
policies = [AsyncRetryPolicy(), NaughtyPolicy()]
@@ -107,11 +140,11 @@ def send(*args):
107140
@pytest.mark.asyncio
108141
async def test_add_custom_policy():
109142
class BooPolicy(AsyncHTTPPolicy):
110-
def send(*args):
143+
async def send(*args):
111144
raise BaseError("boo")
112145

113146
class FooPolicy(AsyncHTTPPolicy):
114-
def send(*args):
147+
async def send(*args):
115148
raise BaseError("boo")
116149

117150
retry_policy = AsyncRetryPolicy()

sdk/core/corehttp/tests/test_authentication.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ServiceKeyCredentialPolicy,
1616
)
1717
from corehttp.rest import HttpRequest
18+
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
1819
import pytest
1920

2021

@@ -251,6 +252,25 @@ def raise_the_second_time(*args, **kwargs):
251252
policy.on_exception.assert_called_once_with(policy.request)
252253

253254

255+
def test_azure_core_sans_io_policy():
256+
"""Tests to see that we can use an azure.core SansIOHTTPPolicy with the corehttp Pipeline"""
257+
258+
class TestPolicy(AzureKeyCredentialPolicy):
259+
def __init__(self, *args, **kwargs):
260+
super(TestPolicy, self).__init__(*args, **kwargs)
261+
self.on_exception = Mock(return_value=False)
262+
self.on_request = Mock()
263+
264+
credential = Mock(get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)), key="key")
265+
policy = TestPolicy(credential, "scope")
266+
transport = Mock(send=Mock(return_value=Mock(status_code=200)))
267+
268+
pipeline = Pipeline(transport=transport, policies=[policy])
269+
pipeline.run(HttpRequest("GET", "https://localhost"))
270+
271+
policy.on_request.assert_called_once()
272+
273+
254274
def test_service_key_credential_policy():
255275
"""Tests to see if we can create an ServiceKeyCredentialPolicy"""
256276

sdk/core/corehttp/tests/test_pipeline.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66

7+
from unittest.mock import Mock
78
import json
89
from io import BytesIO
910
import xml.etree.ElementTree as ET
@@ -52,6 +53,31 @@ def __exit__(self, exc_type, exc_value, traceback):
5253
pipeline.run(req)
5354

5455

56+
def test_invalid_policy_error():
57+
# non-HTTPPolicy/non-SansIOHTTPPolicy should raise an error
58+
class FooPolicy:
59+
pass
60+
61+
# only on_request should raise an error
62+
class OnlyOnRequestPolicy:
63+
def on_request(self, request):
64+
pass
65+
66+
# only on_response should raise an error
67+
class OnlyOnResponsePolicy:
68+
def on_response(self, request, response):
69+
pass
70+
71+
with pytest.raises(AttributeError):
72+
pipeline = Pipeline(transport=Mock(), policies=[FooPolicy()])
73+
74+
with pytest.raises(AttributeError):
75+
pipeline = Pipeline(transport=Mock(), policies=[OnlyOnRequestPolicy()])
76+
77+
with pytest.raises(AttributeError):
78+
pipeline = Pipeline(transport=Mock(), policies=[OnlyOnResponsePolicy()])
79+
80+
5581
@pytest.mark.parametrize("transport", SYNC_TRANSPORTS)
5682
def test_transport_socket_timeout(transport):
5783
request = HttpRequest("GET", "https://bing.com")

0 commit comments

Comments
 (0)