Skip to content

Commit ba102ee

Browse files
Copilotxiangyan99
andcommitted
Implement claims challenge error for AzurePowerShellCredential
Co-authored-by: xiangyan99 <[email protected]>
1 parent 357bba6 commit ba102ee

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

sdk/identity/azure-identity/azure/identity/_credentials/azure_powershell.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def close(self) -> None:
112112
def get_token(
113113
self,
114114
*scopes: str,
115-
claims: Optional[str] = None, # pylint:disable=unused-argument
115+
claims: Optional[str] = None,
116116
tenant_id: Optional[str] = None,
117117
**kwargs: Any,
118118
) -> AccessToken:
@@ -136,6 +136,12 @@ def get_token(
136136
receive an access token
137137
"""
138138

139+
# Check if claims challenge is provided
140+
if claims:
141+
raise CredentialUnavailableError(
142+
message="Fail to get token, please run Connect-AzAccount --ClaimsChallenge"
143+
)
144+
139145
options: TokenRequestOptions = {}
140146
if tenant_id:
141147
options["tenant_id"] = tenant_id
@@ -170,6 +176,12 @@ def _get_token_base(
170176
self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any
171177
) -> AccessTokenInfo:
172178

179+
# Check if claims challenge is provided
180+
if options and options.get("claims"):
181+
raise CredentialUnavailableError(
182+
message="Fail to get token, please run Connect-AzAccount --ClaimsChallenge"
183+
)
184+
173185
tenant_id = options.get("tenant_id") if options else None
174186
if tenant_id:
175187
validate_tenant_id(tenant_id)

sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_powershell.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
async def get_token(
5959
self,
6060
*scopes: str,
61-
claims: Optional[str] = None, # pylint:disable=unused-argument
61+
claims: Optional[str] = None,
6262
tenant_id: Optional[str] = None,
6363
**kwargs: Any,
6464
) -> AccessToken:
@@ -80,6 +80,13 @@ async def get_token(
8080
:raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't
8181
receive an access token
8282
"""
83+
84+
# Check if claims challenge is provided
85+
if claims:
86+
raise CredentialUnavailableError(
87+
message="Fail to get token, please run Connect-AzAccount --ClaimsChallenge"
88+
)
89+
8390
# only ProactorEventLoop supports subprocesses on Windows (and it isn't the default loop on Python < 3.8)
8491
if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop):
8592
return _SyncCredential().get_token(*scopes, tenant_id=tenant_id, **kwargs)
@@ -112,13 +119,27 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio
112119
:raises ~azure.core.exceptions.ClientAuthenticationError: the credential invoked Azure PowerShell but didn't
113120
receive an access token
114121
"""
122+
123+
# Check if claims challenge is provided
124+
if options and options.get("claims"):
125+
raise CredentialUnavailableError(
126+
message="Fail to get token, please run Connect-AzAccount --ClaimsChallenge"
127+
)
128+
115129
if sys.platform.startswith("win") and not isinstance(asyncio.get_event_loop(), asyncio.ProactorEventLoop):
116130
return _SyncCredential().get_token_info(*scopes, options=options)
117131
return await self._get_token_base(*scopes, options=options)
118132

119133
async def _get_token_base(
120134
self, *scopes: str, options: Optional[TokenRequestOptions] = None, **kwargs: Any
121135
) -> AccessTokenInfo:
136+
137+
# Check if claims challenge is provided
138+
if options and options.get("claims"):
139+
raise CredentialUnavailableError(
140+
message="Fail to get token, please run Connect-AzAccount --ClaimsChallenge"
141+
)
142+
122143
tenant_id = options.get("tenant_id") if options else None
123144
if tenant_id:
124145
validate_tenant_id(tenant_id)

sdk/identity/azure-identity/tests/test_powershell_credential.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,52 @@ def fake_Popen(command, **_):
380380
kwargs = {"options": kwargs}
381381
token = getattr(credential, get_token_method)("scope", **kwargs)
382382
assert token.token == expected_token
383+
384+
385+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
386+
def test_claims_challenge_error(get_token_method):
387+
"""The credential should raise CredentialUnavailableError when claims challenge is provided"""
388+
389+
if get_token_method == "get_token":
390+
# Test claims parameter in get_token method
391+
with pytest.raises(CredentialUnavailableError) as exc_info:
392+
getattr(AzurePowerShellCredential(), get_token_method)("scope", claims="some-claims")
393+
assert "Fail to get token, please run Connect-AzAccount --ClaimsChallenge" in str(exc_info.value)
394+
else:
395+
# Test claims in options for get_token_info method
396+
with pytest.raises(CredentialUnavailableError) as exc_info:
397+
getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": "some-claims"})
398+
assert "Fail to get token, please run Connect-AzAccount --ClaimsChallenge" in str(exc_info.value)
399+
400+
401+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
402+
def test_empty_claims_no_error(get_token_method):
403+
"""The credential should not raise error for empty or None claims"""
404+
405+
# Mock successful token response
406+
expected_access_token = "access"
407+
expected_expires_on = 1617923581
408+
stdout = "azsdk%{}%{}".format(expected_access_token, expected_expires_on)
409+
410+
Popen = get_mock_Popen(stdout=stdout)
411+
with patch(POPEN, Popen):
412+
if get_token_method == "get_token":
413+
# Test None claims parameter in get_token method
414+
token = getattr(AzurePowerShellCredential(), get_token_method)("scope", claims=None)
415+
assert token.token == expected_access_token
416+
417+
# Test empty string claims
418+
token = getattr(AzurePowerShellCredential(), get_token_method)("scope", claims="")
419+
assert token.token == expected_access_token
420+
else:
421+
# Test None claims in options for get_token_info method
422+
token = getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": None})
423+
assert token.token == expected_access_token
424+
425+
# Test empty string claims in options
426+
token = getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": ""})
427+
assert token.token == expected_access_token
428+
429+
# Test missing claims key in options
430+
token = getattr(AzurePowerShellCredential(), get_token_method)("scope", options={})
431+
assert token.token == expected_access_token

sdk/identity/azure-identity/tests/test_powershell_credential_async.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,52 @@ async def fake_exec(*args, **_):
389389
kwargs = {"options": kwargs}
390390
token = await getattr(credential, get_token_method)("scope", **kwargs)
391391
assert token.token == expected_token
392+
393+
394+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
395+
async def test_claims_challenge_error(get_token_method):
396+
"""The credential should raise CredentialUnavailableError when claims challenge is provided"""
397+
398+
if get_token_method == "get_token":
399+
# Test claims parameter in get_token method
400+
with pytest.raises(CredentialUnavailableError) as exc_info:
401+
await getattr(AzurePowerShellCredential(), get_token_method)("scope", claims="some-claims")
402+
assert "Fail to get token, please run Connect-AzAccount --ClaimsChallenge" in str(exc_info.value)
403+
else:
404+
# Test claims in options for get_token_info method
405+
with pytest.raises(CredentialUnavailableError) as exc_info:
406+
await getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": "some-claims"})
407+
assert "Fail to get token, please run Connect-AzAccount --ClaimsChallenge" in str(exc_info.value)
408+
409+
410+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
411+
async def test_empty_claims_no_error(get_token_method):
412+
"""The credential should not raise error for empty or None claims"""
413+
414+
# Mock successful token response
415+
expected_access_token = "access"
416+
expected_expires_on = 1617923581
417+
stdout = "azsdk%{}%{}".format(expected_access_token, expected_expires_on)
418+
419+
exec_mock = get_mock_exec(stdout=stdout)
420+
with patch(CREATE_SUBPROCESS_EXEC, exec_mock):
421+
if get_token_method == "get_token":
422+
# Test None claims parameter in get_token method
423+
token = await getattr(AzurePowerShellCredential(), get_token_method)("scope", claims=None)
424+
assert token.token == expected_access_token
425+
426+
# Test empty string claims
427+
token = await getattr(AzurePowerShellCredential(), get_token_method)("scope", claims="")
428+
assert token.token == expected_access_token
429+
else:
430+
# Test None claims in options for get_token_info method
431+
token = await getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": None})
432+
assert token.token == expected_access_token
433+
434+
# Test empty string claims in options
435+
token = await getattr(AzurePowerShellCredential(), get_token_method)("scope", options={"claims": ""})
436+
assert token.token == expected_access_token
437+
438+
# Test missing claims key in options
439+
token = await getattr(AzurePowerShellCredential(), get_token_method)("scope", options={})
440+
assert token.token == expected_access_token

0 commit comments

Comments
 (0)