Skip to content

Commit 9881093

Browse files
authored
[Key Vault] Correctly handle token refreshes in AD FS (#33634)
1 parent 0899474 commit 9881093

File tree

10 files changed

+55
-18
lines changed

10 files changed

+55
-18
lines changed

sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from . import http_challenge_cache as ChallengeCache
2727
from .challenge_auth_policy import _enforce_tls, _update_challenge
2828

29-
3029
class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
3130
"""Policy for handling HTTP authentication challenges.
3231
@@ -49,7 +48,11 @@ async def on_request(self, request: PipelineRequest) -> None:
4948
if self._need_new_token():
5049
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
5150
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
52-
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
51+
# Exclude tenant for AD FS authentication
52+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
53+
self._token = await self._credential.get_token(scope)
54+
else:
55+
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
5356

5457
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
5558
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
6666
def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
6767
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs)
6868
self._credential = credential
69-
self._token: "Optional[AccessToken]" = None
69+
self._token: Optional[AccessToken] = None
7070
self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True)
7171

7272
def on_request(self, request: PipelineRequest) -> None:
@@ -77,7 +77,11 @@ def on_request(self, request: PipelineRequest) -> None:
7777
if self._need_new_token:
7878
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
7979
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
80-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
80+
# Exclude tenant for AD FS authentication
81+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
82+
self._token = self._credential.get_token(scope)
83+
else:
84+
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
8185

8286
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
8387
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ async def on_request(self, request: PipelineRequest) -> None:
4848
if self._need_new_token():
4949
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
5050
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
51-
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
51+
# Exclude tenant for AD FS authentication
52+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
53+
self._token = await self._credential.get_token(scope)
54+
else:
55+
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
5256

5357
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
5458
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def on_request(self, request: PipelineRequest) -> None:
7777
if self._need_new_token:
7878
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
7979
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
80-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
80+
# Exclude tenant for AD FS authentication
81+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
82+
self._token = self._credential.get_token(scope)
83+
else:
84+
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
8185

8286
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
8387
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ async def on_request(self, request: PipelineRequest) -> None:
4848
if self._need_new_token():
4949
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
5050
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
51-
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
51+
# Exclude tenant for AD FS authentication
52+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
53+
self._token = await self._credential.get_token(scope)
54+
else:
55+
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
5256

5357
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
5458
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def on_request(self, request: PipelineRequest) -> None:
7777
if self._need_new_token:
7878
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
7979
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
80-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
80+
# Exclude tenant for AD FS authentication
81+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
82+
self._token = self._credential.get_token(scope)
83+
else:
84+
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
8185

8286
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
8387
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def send(request):
262262
assert not request.body
263263
assert request.headers["Content-Length"] == "0"
264264
return challenge
265-
elif Requests.count == 2:
265+
elif Requests.count in (2, 3):
266266
# second request should be authorized according to challenge and have the expected content
267267
assert request.headers["Content-Length"]
268268
assert request.body == expected_content
@@ -276,13 +276,17 @@ def get_token(*_, **kwargs):
276276
return AccessToken(expected_token, 0)
277277

278278
credential = Mock(get_token=Mock(wraps=get_token))
279-
pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send))
279+
policy = ChallengeAuthPolicy(credential=credential)
280+
pipeline = Pipeline(policies=[policy], transport=Mock(send=send))
280281
request = HttpRequest("POST", get_random_url())
281282
request.set_bytes_body(expected_content)
282283
pipeline.run(request)
283-
284284
assert credential.get_token.call_count == 1
285285

286+
# Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621
287+
policy._token = None
288+
pipeline.run(request)
289+
286290
tenant = "tenant-id"
287291
# AD FS challenges have an unusual authority format; see https://github.com/Azure/azure-sdk-for-python/issues/28648
288292
endpoint = f"https://adfs.redmond.azurestack.corp.microsoft.com/adfs/{tenant}"

sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ async def send(request):
218218
assert not request.body
219219
assert request.headers["Content-Length"] == "0"
220220
return challenge
221-
elif Requests.count == 2:
221+
elif Requests.count in (2, 3):
222222
# second request should be authorized according to challenge and have the expected content
223223
assert request.headers["Content-Length"]
224224
assert request.body == expected_content
@@ -232,15 +232,17 @@ async def get_token(*_, **kwargs):
232232
return AccessToken(expected_token, 0)
233233

234234
credential = Mock(get_token=Mock(wraps=get_token))
235-
pipeline = AsyncPipeline(
236-
policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)
237-
)
235+
policy = AsyncChallengeAuthPolicy(credential=credential)
236+
pipeline = AsyncPipeline(policies=[policy], transport=Mock(send=send))
238237
request = HttpRequest("POST", get_random_url())
239238
request.set_bytes_body(expected_content)
240239
await pipeline.run(request)
241-
242240
assert credential.get_token.call_count == 1
243241

242+
# Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621
243+
policy._token = None
244+
await pipeline.run(request)
245+
244246
tenant = "tenant-id"
245247
# AD FS challenges have an unusual authority format; see https://github.com/Azure/azure-sdk-for-python/issues/28648
246248
endpoint = f"https://adfs.redmond.azurestack.corp.microsoft.com/adfs/{tenant}"

sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ async def on_request(self, request: PipelineRequest) -> None:
4848
if self._need_new_token():
4949
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
5050
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
51-
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
51+
# Exclude tenant for AD FS authentication
52+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
53+
self._token = await self._credential.get_token(scope)
54+
else:
55+
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
5256

5357
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
5458
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def on_request(self, request: PipelineRequest) -> None:
7777
if self._need_new_token:
7878
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
7979
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
80-
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
80+
# Exclude tenant for AD FS authentication
81+
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
82+
self._token = self._credential.get_token(scope)
83+
else:
84+
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
8185

8286
# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
8387
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore

0 commit comments

Comments
 (0)