Skip to content

Commit e588964

Browse files
pvaneckscbedd
andauthored
[Core] Update async auth policy lock logic (#33282)
Sometimes users might be using trio and its mechanisms for running tasks concurrently. This updates the lock initilization logic to check if the user is in an asyncio event loop and sets the lock accordingly. If not, assume trio. We don't expect users to be in anything else other than asyncio or trio event loops. A utility function was added that tries to determine the current async library being used. Signed-off-by: Paul Van Eck <[email protected]> Co-authored-by: Scott Beddall (from Dev Box) <[email protected]>
1 parent 5a3ab93 commit e588964

File tree

6 files changed

+93
-5
lines changed

6 files changed

+93
-5
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
### Other Changes
1212

13+
- Removed dependency on `anyio`. #33282
14+
1315
## 1.29.6 (2023-12-14)
1416

1517
### Bugs Fixed

sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import time
77
from typing import TYPE_CHECKING, Any, Awaitable, Optional, cast, TypeVar
88

9-
from anyio import Lock
109
from azure.core.credentials import AccessToken
1110
from azure.core.pipeline import PipelineRequest, PipelineResponse
1211
from azure.core.pipeline.policies import AsyncHTTPPolicy
@@ -15,6 +14,7 @@
1514
)
1615
from azure.core.pipeline.transport import AsyncHttpResponse as LegacyAsyncHttpResponse, HttpRequest as LegacyHttpRequest
1716
from azure.core.rest import AsyncHttpResponse, HttpRequest
17+
from azure.core.utils._utils import get_running_async_lock
1818

1919
from .._tools_async import await_result
2020

@@ -38,11 +38,17 @@ class AsyncBearerTokenCredentialPolicy(AsyncHTTPPolicy[HTTPRequestType, AsyncHTT
3838
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: Any) -> None:
3939
super().__init__()
4040
self._credential = credential
41-
self._lock = Lock()
4241
self._scopes = scopes
42+
self._lock_instance = None
4343
self._token: Optional["AccessToken"] = None
4444
self._enable_cae: bool = kwargs.get("enable_cae", False)
4545

46+
@property
47+
def _lock(self):
48+
if self._lock_instance is None:
49+
self._lock_instance = get_running_async_lock()
50+
return self._lock_instance
51+
4652
async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4753
"""Adds a bearer token Authorization header to request and sends request to next policy.
4854

sdk/core/azure-core/azure/core/utils/_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
# license information.
66
# --------------------------------------------------------------------------
77
import datetime
8+
import sys
89
from typing import (
910
Any,
11+
AsyncContextManager,
1012
Iterable,
1113
Iterator,
1214
Mapping,
@@ -161,3 +163,26 @@ def __eq__(self, other: Any) -> bool:
161163

162164
def __repr__(self) -> str:
163165
return str(dict(self.items()))
166+
167+
168+
def get_running_async_lock() -> AsyncContextManager:
169+
"""Get a lock instance from the async library that the current context is running under.
170+
171+
:return: An instance of the running async library's Lock class.
172+
:rtype: AsyncContextManager
173+
:raises: RuntimeError if the current context is not running under an async library.
174+
"""
175+
176+
try:
177+
import asyncio
178+
179+
# Check if we are running in an asyncio event loop.
180+
asyncio.get_running_loop()
181+
return asyncio.Lock()
182+
except RuntimeError as err:
183+
# Otherwise, assume we are running in a trio event loop if it has already been imported.
184+
if "trio" in sys.modules:
185+
import trio # pylint: disable=networking-import-outside-azure-core-transport
186+
187+
return trio.Lock()
188+
raise RuntimeError("An asyncio or trio event loop is required.") from err

sdk/core/azure-core/setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
},
7070
python_requires=">=3.7",
7171
install_requires=[
72-
"anyio>=3.0,<5.0",
7372
"requests>=2.21.0",
7473
"six>=1.11.0",
7574
"typing-extensions>=4.6.0",

sdk/core/azure-core/tests/async_tests/test_authentication_async.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import asyncio
7+
import sys
78
import time
8-
from unittest.mock import Mock
9+
from unittest.mock import Mock, patch
910
from requests import Response
1011

1112
from azure.core.credentials import AccessToken
@@ -20,11 +21,12 @@
2021
)
2122
from azure.core.pipeline.transport import AsyncHttpTransport, HttpRequest
2223
import pytest
24+
import trio
2325

24-
pytestmark = pytest.mark.asyncio
2526
from utils import HTTP_REQUESTS
2627

2728

29+
@pytest.mark.asyncio
2830
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
2931
async def test_bearer_policy_adds_header(http_request):
3032
"""The bearer token policy should add a header containing a token from its credential"""
@@ -54,6 +56,7 @@ async def get_token(*_, **__):
5456
assert get_token_calls == 1
5557

5658

59+
@pytest.mark.asyncio
5760
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
5861
async def test_bearer_policy_send(http_request):
5962
"""The bearer token policy should invoke the next policy's send method and return the result"""
@@ -72,6 +75,7 @@ async def verify_request(request):
7275
assert response is expected_response
7376

7477

78+
@pytest.mark.asyncio
7579
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
7680
async def test_bearer_policy_sync_send(http_request):
7781
"""The bearer token policy should invoke the next policy's send method and return the result"""
@@ -90,6 +94,7 @@ async def verify_request(request):
9094
assert response is expected_response
9195

9296

97+
@pytest.mark.asyncio
9398
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
9499
async def test_bearer_policy_token_caching(http_request):
95100
good_for_one_hour = AccessToken("token", time.time() + 3600)
@@ -130,6 +135,7 @@ async def get_token(*_, **__):
130135
assert get_token_calls == 2 # token expired -> policy should call get_token
131136

132137

138+
@pytest.mark.asyncio
133139
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
134140
async def test_bearer_policy_optionally_enforces_https(http_request):
135141
"""HTTPS enforcement should be controlled by a keyword argument, and enabled by default"""
@@ -158,6 +164,7 @@ async def assert_option_popped(request, **kwargs):
158164
await pipeline.run(http_request("GET", "https://secure"))
159165

160166

167+
@pytest.mark.asyncio
161168
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
162169
async def test_bearer_policy_preserves_enforce_https_opt_out(http_request):
163170
"""The policy should use request context to preserve an opt out from https enforcement"""
@@ -175,6 +182,7 @@ def on_request(self, request):
175182
await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False)
176183

177184

185+
@pytest.mark.asyncio
178186
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
179187
async def test_bearer_policy_context_unmodified_by_default(http_request):
180188
"""When no options for the policy accompany a request, the policy shouldn't add anything to the request context"""
@@ -192,6 +200,7 @@ def on_request(self, request):
192200
await pipeline.run(http_request("GET", "https://secure"))
193201

194202

203+
@pytest.mark.asyncio
195204
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
196205
async def test_bearer_policy_calls_sansio_methods(http_request):
197206
"""AsyncBearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOAsyncHTTPPolicyRunner"""
@@ -440,3 +449,24 @@ async def get_token(self, *scopes, **kwargs):
440449

441450
cred = TestTokenCredential()
442451
await cred.get_token("scope")
452+
453+
454+
@pytest.mark.asyncio
455+
async def test_async_token_credential_asyncio_lock():
456+
auth_policy = AsyncBearerTokenCredentialPolicy(Mock(), "scope")
457+
assert isinstance(auth_policy._lock, asyncio.Lock)
458+
459+
460+
@pytest.mark.trio
461+
async def test_async_token_credential_trio_lock():
462+
auth_policy = AsyncBearerTokenCredentialPolicy(Mock(), "scope")
463+
assert isinstance(auth_policy._lock, trio.Lock)
464+
465+
466+
def test_async_token_credential_sync():
467+
"""Verify that AsyncBearerTokenCredentialPolicy can be constructed in a synchronous context."""
468+
auth_policy = AsyncBearerTokenCredentialPolicy(Mock(), "scope")
469+
with patch.dict("sys.modules"):
470+
# Ensure trio isn't in sys.modules (i.e. imported).
471+
sys.modules.pop("trio", None)
472+
AsyncBearerTokenCredentialPolicy(Mock(), "scope")

sdk/core/azure-core/tests/test_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
import sys
6+
from unittest.mock import patch
7+
58
import pytest
69
from azure.core.utils import case_insensitive_dict
10+
from azure.core.utils._utils import get_running_async_lock
711

812

913
@pytest.fixture()
@@ -108,3 +112,25 @@ def test_case_iter():
108112

109113
for key in my_dict:
110114
assert key in keys
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_get_running_async_module_asyncio():
119+
import asyncio
120+
121+
assert isinstance(get_running_async_lock(), asyncio.Lock)
122+
123+
124+
@pytest.mark.trio
125+
async def test_get_running_async_module_trio():
126+
import trio
127+
128+
assert isinstance(get_running_async_lock(), trio.Lock)
129+
130+
131+
def test_get_running_async_module_sync():
132+
with patch.dict("sys.modules"):
133+
# Ensure trio isn't in sys.modules (i.e. imported).
134+
sys.modules.pop("trio", None)
135+
with pytest.raises(RuntimeError):
136+
get_running_async_lock()

0 commit comments

Comments
 (0)