Skip to content

Commit 55632a7

Browse files
xiangyan99pvaneck
andauthored
Add default impl to handle token challenges (Azure#37652)
* Add default impl to handle token challenges * update version * update * update * update * update * Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck <[email protected]> * update * Update sdk/core/azure-core/tests/test_utils.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/core/azure-core/azure/core/pipeline/policies/_utils.py Co-authored-by: Paul Van Eck <[email protected]> * update --------- Co-authored-by: Paul Van Eck <[email protected]>
1 parent 37a2e61 commit 55632a7

File tree

6 files changed

+225
-28
lines changed

6 files changed

+225
-28
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Release History
22

3-
## 1.31.1 (Unreleased)
3+
## 1.32.0 (Unreleased)
44

55
### Features Added
66

7+
- Added a default implementation to handle token challenges in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy`.
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/core/azure-core/azure/core/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# regenerated.
1010
# --------------------------------------------------------------------------
1111

12-
VERSION = "1.31.1"
12+
VERSION = "1.32.0"

sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import time
7+
import base64
78
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
89
from azure.core.credentials import (
910
TokenCredential,
@@ -19,6 +20,7 @@
1920
from azure.core.rest import HttpResponse, HttpRequest
2021
from . import HTTPPolicy, SansIOHTTPPolicy
2122
from ...exceptions import ServiceRequestError
23+
from ._utils import get_challenge_parameter
2224

2325
if TYPE_CHECKING:
2426

@@ -82,13 +84,7 @@ def _need_new_token(self) -> bool:
8284
refresh_on = getattr(self._token, "refresh_on", None)
8385
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
8486

85-
def _request_token(self, *scopes: str, **kwargs: Any) -> None:
86-
"""Request a new token from the credential.
87-
88-
This will call the credential's appropriate method to get a token and store it in the policy.
89-
90-
:param str scopes: The type of access needed.
91-
"""
87+
def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
9288
if self._enable_cae:
9389
kwargs.setdefault("enable_cae", self._enable_cae)
9490

@@ -99,9 +95,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> None:
9995
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
10096
options[key] = kwargs.pop(key) # type: ignore[literal-required]
10197

102-
self._token = cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
103-
else:
104-
self._token = cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
98+
return cast(SupportsTokenInfo, self._credential).get_token_info(*scopes, options=options)
99+
return cast(TokenCredential, self._credential).get_token(*scopes, **kwargs)
100+
101+
def _request_token(self, *scopes: str, **kwargs: Any) -> None:
102+
"""Request a new token from the credential.
103+
104+
This will call the credential's appropriate method to get a token and store it in the policy.
105+
106+
:param str scopes: The type of access needed.
107+
"""
108+
self._token = self._get_token(*scopes, **kwargs)
105109

106110

107111
class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
@@ -191,6 +195,22 @@ def on_challenge(
191195
:rtype: bool
192196
"""
193197
# pylint:disable=unused-argument
198+
headers = response.http_response.headers
199+
error = get_challenge_parameter(headers, "Bearer", "error")
200+
if error == "insufficient_claims":
201+
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
202+
if not encoded_claims:
203+
return False
204+
try:
205+
padding_needed = -len(encoded_claims) % 4
206+
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
207+
if claims:
208+
token = self._get_token(*self._scopes, claims=claims)
209+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
210+
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
211+
return True
212+
except Exception: # pylint:disable=broad-except
213+
return False
194214
return False
195215

196216
def on_response(

sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import time
7+
import base64
78
from typing import Any, Awaitable, Optional, cast, TypeVar, Union
89

910
from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
@@ -23,6 +24,7 @@
2324
)
2425
from azure.core.rest import AsyncHttpResponse, HttpRequest
2526
from azure.core.utils._utils import get_running_async_lock
27+
from ._utils import get_challenge_parameter
2628

2729
from .._tools_async import await_result
2830

@@ -138,6 +140,22 @@ async def on_challenge(
138140
:rtype: bool
139141
"""
140142
# pylint:disable=unused-argument
143+
headers = response.http_response.headers
144+
error = get_challenge_parameter(headers, "Bearer", "error")
145+
if error == "insufficient_claims":
146+
encoded_claims = get_challenge_parameter(headers, "Bearer", "claims")
147+
if not encoded_claims:
148+
return False
149+
try:
150+
padding_needed = -len(encoded_claims) % 4
151+
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
152+
if claims:
153+
token = await self._get_token(*self._scopes, claims=claims)
154+
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
155+
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
156+
return True
157+
except Exception: # pylint:disable=broad-except
158+
return False
141159
return False
142160

143161
def on_response(
@@ -169,13 +187,7 @@ def _need_new_token(self) -> bool:
169187
refresh_on = getattr(self._token, "refresh_on", None)
170188
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
171189

172-
async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
173-
"""Request a new token from the credential.
174-
175-
This will call the credential's appropriate method to get a token and store it in the policy.
176-
177-
:param str scopes: The type of access needed.
178-
"""
190+
async def _get_token(self, *scopes: str, **kwargs: Any) -> Union["AccessToken", "AccessTokenInfo"]:
179191
if self._enable_cae:
180192
kwargs.setdefault("enable_cae", self._enable_cae)
181193

@@ -186,14 +198,22 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
186198
if key in TokenRequestOptions.__annotations__: # pylint: disable=no-member
187199
options[key] = kwargs.pop(key) # type: ignore[literal-required]
188200

189-
self._token = await await_result(
201+
return await await_result(
190202
cast(AsyncSupportsTokenInfo, self._credential).get_token_info,
191203
*scopes,
192204
options=options,
193205
)
194-
else:
195-
self._token = await await_result(
196-
cast(AsyncTokenCredential, self._credential).get_token,
197-
*scopes,
198-
**kwargs,
199-
)
206+
return await await_result(
207+
cast(AsyncTokenCredential, self._credential).get_token,
208+
*scopes,
209+
**kwargs,
210+
)
211+
212+
async def _request_token(self, *scopes: str, **kwargs: Any) -> None:
213+
"""Request a new token from the credential.
214+
215+
This will call the credential's appropriate method to get a token and store it in the policy.
216+
217+
:param str scopes: The type of access needed.
218+
"""
219+
self._token = await self._get_token(*scopes, **kwargs)

sdk/core/azure-core/azure/core/pipeline/policies/_utils.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# --------------------------------------------------------------------------
2626
import datetime
2727
import email.utils
28-
from typing import Optional, cast, Union
28+
from typing import Optional, cast, Union, Tuple
2929
from urllib.parse import urlparse
3030

3131
from azure.core.pipeline.transport import (
@@ -102,3 +102,103 @@ def get_domain(url: str) -> str:
102102
:return: The domain of the url.
103103
"""
104104
return str(urlparse(url).netloc).lower()
105+
106+
107+
def get_challenge_parameter(headers, challenge_scheme: str, challenge_parameter: str) -> Optional[str]:
108+
"""
109+
Parses the specified parameter from a challenge header found in the response.
110+
111+
:param dict[str, str] headers: The response headers to parse.
112+
:param str challenge_scheme: The challenge scheme containing the challenge parameter, e.g., "Bearer".
113+
:param str challenge_parameter: The parameter key name to search for.
114+
:return: The value of the parameter name if found.
115+
:rtype: str or None
116+
"""
117+
header_value = headers.get("WWW-Authenticate")
118+
if not header_value:
119+
return None
120+
121+
scheme = challenge_scheme
122+
parameter = challenge_parameter
123+
header_span = header_value
124+
125+
# Iterate through each challenge value.
126+
while True:
127+
challenge = get_next_challenge(header_span)
128+
if not challenge:
129+
break
130+
challenge_key, header_span = challenge
131+
if challenge_key.lower() != scheme.lower():
132+
continue
133+
# Enumerate each key-value parameter until we find the parameter key on the specified scheme challenge.
134+
while True:
135+
parameters = get_next_parameter(header_span)
136+
if not parameters:
137+
break
138+
key, value, header_span = parameters
139+
if key.lower() == parameter.lower():
140+
return value
141+
142+
return None
143+
144+
145+
def get_next_challenge(header_value: str) -> Optional[Tuple[str, str]]:
146+
"""
147+
Iterates through the challenge schemes present in a challenge header.
148+
149+
:param str header_value: The header value which will be sliced to remove the first parsed challenge key.
150+
:return: The parsed challenge scheme and the remaining header value.
151+
:rtype: tuple[str, str] or None
152+
"""
153+
header_value = header_value.lstrip(" ")
154+
end_of_challenge_key = header_value.find(" ")
155+
156+
if end_of_challenge_key < 0:
157+
return None
158+
159+
challenge_key = header_value[:end_of_challenge_key]
160+
header_value = header_value[end_of_challenge_key + 1 :]
161+
162+
return challenge_key, header_value
163+
164+
165+
def get_next_parameter(header_value: str, separator: str = "=") -> Optional[Tuple[str, str, str]]:
166+
"""
167+
Iterates through a challenge header value to extract key-value parameters.
168+
169+
:param str header_value: The header value after being parsed by get_next_challenge.
170+
:param str separator: The challenge parameter key-value pair separator, default is '='.
171+
:return: The next available challenge parameter as a tuple (param_key, param_value, remaining header_value).
172+
:rtype: tuple[str, str, str] or None
173+
"""
174+
space_or_comma = " ,"
175+
header_value = header_value.lstrip(space_or_comma)
176+
177+
next_space = header_value.find(" ")
178+
next_separator = header_value.find(separator)
179+
180+
if next_space < next_separator and next_space != -1:
181+
return None
182+
183+
if next_separator < 0:
184+
return None
185+
186+
param_key = header_value[:next_separator].strip()
187+
header_value = header_value[next_separator + 1 :]
188+
189+
quote_index = header_value.find('"')
190+
191+
if quote_index >= 0:
192+
header_value = header_value[quote_index + 1 :]
193+
param_value = header_value[: header_value.find('"')]
194+
else:
195+
trailing_delimiter_index = header_value.find(" ")
196+
if trailing_delimiter_index >= 0:
197+
param_value = header_value[:trailing_delimiter_index]
198+
else:
199+
param_value = header_value
200+
201+
if header_value != param_value:
202+
header_value = header_value[len(param_value) + 1 :]
203+
204+
return param_key, param_value, header_value

sdk/core/azure-core/tests/test_utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from azure.core.utils import case_insensitive_dict
1010
from azure.core.utils._utils import get_running_async_lock
11-
from azure.core.pipeline.policies._utils import parse_retry_after
11+
from azure.core.pipeline.policies._utils import parse_retry_after, get_challenge_parameter
1212

1313

1414
@pytest.fixture()
@@ -146,3 +146,58 @@ def test_parse_retry_after():
146146
assert ret == 0
147147
ret = parse_retry_after("0.9")
148148
assert ret == 0.9
149+
150+
151+
def test_get_challenge_parameter():
152+
headers = {
153+
"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
154+
}
155+
assert (
156+
get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id"
157+
)
158+
assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net"
159+
assert get_challenge_parameter(headers, "Bearer", "foo") is None
160+
161+
headers = {
162+
"WWW-Authenticate": 'Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'
163+
}
164+
assert (
165+
get_challenge_parameter(headers, "Bearer", "authorization_uri")
166+
== "https://login.microsoftonline.com/common/oauth2/authorize"
167+
)
168+
assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims"
169+
assert (
170+
get_challenge_parameter(headers, "Bearer", "claims")
171+
== "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="
172+
)
173+
174+
175+
def test_get_challenge_parameter_not_found():
176+
headers = {
177+
"WWW-Authenticate": 'Pop authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
178+
}
179+
assert get_challenge_parameter(headers, "Bearer", "resource") is None
180+
181+
182+
def test_get_multi_challenge_parameter():
183+
headers = {
184+
"WWW-Authenticate": 'Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net" Bearer authorization_uri="https://login.microsoftonline.com/tenant-id", resource="https://vault.azure.net"'
185+
}
186+
assert (
187+
get_challenge_parameter(headers, "Bearer", "authorization_uri") == "https://login.microsoftonline.com/tenant-id"
188+
)
189+
assert get_challenge_parameter(headers, "Bearer", "resource") == "https://vault.azure.net"
190+
assert get_challenge_parameter(headers, "Bearer", "foo") is None
191+
192+
headers = {
193+
"WWW-Authenticate": 'Digest realm="[email protected]", qop="auth,auth-int", nonce="123456abcdefg", opaque="123456", Bearer realm="", authorization_uri="https://login.microsoftonline.com/common/oauth2/authorize", error="insufficient_claims", claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="'
194+
}
195+
assert (
196+
get_challenge_parameter(headers, "Bearer", "authorization_uri")
197+
== "https://login.microsoftonline.com/common/oauth2/authorize"
198+
)
199+
assert get_challenge_parameter(headers, "Bearer", "error") == "insufficient_claims"
200+
assert (
201+
get_challenge_parameter(headers, "Bearer", "claims")
202+
== "eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwidmFsdWUiOiIxNzI2MDc3NTk1In0sInhtc19jYWVlcnJvciI6eyJ2YWx1ZSI6IjEwMDEyIn19fQ=="
203+
)

0 commit comments

Comments
 (0)