Skip to content

Commit 71df7d7

Browse files
authored
[ACR] az acr login: Enforce using acr audience in aad token acquisition (#31798)
1 parent 3fa1e89 commit 71df7d7

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from ._constants import get_managed_sku
2929
from ._constants import ACR_AUDIENCE_RESOURCE_NAME
3030
from ._utils import get_registry_by_name, ResourceNotFound
31-
from .policy import acr_config_authentication_as_arm_show
3231
from ._format import add_timestamp
3332
from ._errors import CONNECTIVITY_TOOMANYREQUESTS_ERROR
3433

@@ -135,18 +134,14 @@ def _get_aad_token_after_challenge(cli_ctx,
135134
artifact_repository,
136135
permission,
137136
is_diagnostics_context,
138-
use_acr_audience,
139137
verify_user_permissions):
140138
authurl = urlparse(token_params['realm'])
141139
authhost = urlunparse((authurl[0], authurl[1], '/oauth2/exchange', '', '', ''))
142140

143141
from azure.cli.core._profile import Profile
144142
profile = Profile(cli_ctx=cli_ctx)
145143

146-
scope = None
147-
if use_acr_audience:
148-
logger.debug("Using ACR audience token for authentication")
149-
scope = "https://{}.azure.net".format(ACR_AUDIENCE_RESOURCE_NAME)
144+
scope = "https://{}.azure.net".format(ACR_AUDIENCE_RESOURCE_NAME)
150145

151146
# this might be a cross tenant scenario, so pass subscription to get_raw_token
152147
creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx),
@@ -267,7 +262,6 @@ def _get_aad_token(cli_ctx,
267262
artifact_repository=None,
268263
permission=None,
269264
is_diagnostics_context=False,
270-
use_acr_audience=False,
271265
verify_user_permissions=False):
272266
"""Obtains refresh and access tokens for an AAD-enabled registry. Will return the allowed actions if
273267
verify_user_permissions is set to True.
@@ -296,7 +290,6 @@ def _get_aad_token(cli_ctx,
296290
artifact_repository,
297291
permission,
298292
is_diagnostics_context,
299-
use_acr_audience,
300293
verify_user_permissions)
301294

302295

@@ -453,19 +446,12 @@ def _get_credentials(cmd, # pylint: disable=too-many-statements
453446
if not registry or registry.sku.name in get_managed_sku(cmd):
454447
logger.info("Attempting to retrieve AAD refresh token...")
455448
try:
456-
use_acr_audience = False
457-
458-
if registry:
459-
aad_auth_policy = acr_config_authentication_as_arm_show(cmd, registry_name, resource_group_name)
460-
use_acr_audience = (aad_auth_policy and aad_auth_policy.status == 'disabled')
461-
462449
return login_server, EMPTY_GUID, _get_aad_token(cli_ctx,
463450
login_server,
464451
only_refresh_token,
465452
repository,
466453
artifact_repository,
467-
permission,
468-
use_acr_audience=use_acr_audience)
454+
permission)
469455
except CLIError as e:
470456
raise_toomanyrequests_error(str(e))
471457
logger.warning("%s: %s", AAD_TOKEN_BASE_ERROR_MESSAGE, str(e))

src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,7 @@ def _core_token_scenarios(self, mock_get_raw_token, mock_requests_get, mock_requ
12061206

12071207
# Test get refresh token
12081208
get_login_credentials(cmd, registry_name, tenant_suffix=tenant_suffix)
1209+
self._validate_raw_token_request(mock_get_raw_token)
12091210
self._validate_refresh_token_request(mock_requests_get, mock_requests_post, login_server)
12101211

12111212
# Test get access token for container image repository
@@ -1237,6 +1238,9 @@ def _setup_mock_token_requests(self, mock_get_aad_token, mock_requests_get, mock
12371238
'access_token': TEST_ACR_ACCESS_TOKEN}).encode()
12381239
mock_requests_post.return_value = token_response
12391240

1241+
def _validate_raw_token_request(self, mock_get_raw_token):
1242+
mock_get_raw_token.assert_called_with(mock.ANY, resource="https://containerregistry.azure.net", subscription=mock.ANY)
1243+
12401244
def _validate_refresh_token_request(self, mock_requests_get, mock_requests_post, login_server):
12411245
mock_requests_get.assert_called_with('https://{}/v2/'.format(login_server), verify=mock.ANY)
12421246
mock_requests_post.assert_called_with(

0 commit comments

Comments
 (0)