Skip to content

Commit 4c0766e

Browse files
committed
Add unit test for _get_codeartifact_index
1 parent 455f247 commit 4c0766e

File tree

2 files changed

+134
-10
lines changed

2 files changed

+134
-10
lines changed

src/sagemaker/processing.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pathlib
2323
import logging
2424
from textwrap import dedent
25-
from typing import Dict, List, Optional, Union
25+
from typing import Any, Dict, List, Optional, Union
2626
from copy import copy
2727
import re
2828

@@ -1846,14 +1846,18 @@ def _pack_and_upload_code(
18461846

18471847
return s3_runproc_sh, inputs, job_name
18481848

1849-
def _get_codeartifact_index(self, codeartifact_repo_arn: str):
1849+
def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None):
18501850
"""
18511851
Build the authenticated codeartifact index url based on the arn provided
18521852
via codeartifact_repo_arn property following the form
18531853
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
18541854
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
18551855
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1856-
:return: authenticated codeartifact index url
1856+
Args:
1857+
codeartifact_repo_arn: arn of the codeartifact repository
1858+
codeartifact_client: boto3 client for codeartifact (used for testing)
1859+
Returns:
1860+
authenticated codeartifact index url
18571861
"""
18581862

18591863
arn_regex = (
@@ -1862,7 +1866,7 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18621866
)
18631867
m = re.match(arn_regex, codeartifact_repo_arn)
18641868
if not m:
1865-
raise Exception("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
1869+
raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
18661870
domain = m.group("domain")
18671871
owner = m.group("account")
18681872
repository = m.group("repository")
@@ -1877,10 +1881,12 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18771881
region,
18781882
)
18791883
try:
1880-
client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1881-
auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
1884+
if not codeartifact_client:
1885+
codeartifact_client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1886+
1887+
auth_token_response = codeartifact_client.get_authorization_token(domain=domain, domainOwner=owner)
18821888
token = auth_token_response["authorizationToken"]
1883-
endpoint_response = client.get_repository_endpoint(
1889+
endpoint_response = codeartifact_client.get_repository_endpoint(
18841890
domain=domain, domainOwner=owner, repository=repository, format="pypi"
18851891
)
18861892
unauthenticated_index = endpoint_response["repositoryEndpoint"]
@@ -1893,9 +1899,9 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18931899
unauthenticated_index,
18941900
),
18951901
)
1896-
except Exception:
1897-
logger.error("failed to configure pip to use codeartifact")
1898-
raise Exception("failed to configure pip to use codeartifact")
1902+
except Exception as e:
1903+
logger.error("failed to configure pip to use codeartifact: %s", e, exc_info=True)
1904+
raise RuntimeError("failed to configure pip to use codeartifact")
18991905

19001906
def _generate_framework_script(
19011907
self, user_script: str, codeartifact_repo_arn: str = None

tests/unit/test_processing.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from __future__ import absolute_import
1414

1515
import copy
16+
import datetime
1617

18+
import boto3
19+
from botocore.stub import Stubber
1720
import pytest
1821
from mock import Mock, patch, MagicMock
1922
from packaging import version
@@ -1099,6 +1102,121 @@ def test_pyspark_processor_configuration_path_pipeline_config(
10991102
)
11001103

11011104

1105+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_get_codeartifact_index(pipeline_session):
    """Happy path: the stubbed token and repository endpoint yield the authenticated index URL."""
    repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
    index_location = (
        "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
    )

    processor = PyTorchProcessor(
        role=ROLE,
        instance_type="ml.m4.xlarge",
        framework_version="2.0.1",
        py_version="py310",
        instance_count=1,
        sagemaker_session=pipeline_session,
    )

    # A real boto3 client wrapped in a Stubber: requests are validated against
    # the expected params and answered from the queue instead of hitting AWS.
    stubbed_client = boto3.client("codeartifact", region_name=REGION)
    stubber = Stubber(stubbed_client)
    stubber.add_response(
        "get_authorization_token",
        {
            "authorizationToken": "mocked_token",
            "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0),
        },
        {"domain": "test-domain", "domainOwner": "012345678901"},
    )
    stubber.add_response(
        "get_repository_endpoint",
        {"repositoryEndpoint": f"https://{index_location}"},
        {
            "domain": "test-domain",
            "domainOwner": "012345678901",
            "repository": "test-repository",
            "format": "pypi",
        },
    )

    with stubber:
        authenticated_index = processor._get_codeartifact_index(
            codeartifact_repo_arn=repo_arn,
            codeartifact_client=stubbed_client,
        )

    assert authenticated_index == f"https://aws:mocked_token@{index_location}"
1142+
1143+
1144+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_get_codeartifact_index_bad_repo_arn(pipeline_session):
    """An ARN without the repository segment is rejected with ValueError."""
    # Missing the trailing "/<repository>" component.
    codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"

    client = boto3.client("codeartifact", region_name=REGION)
    # No responses are queued on purpose: the ARN check is expected to fail
    # before any CodeArtifact call is made. The active stubber makes any
    # accidental call fail locally instead of reaching AWS.
    stubber = Stubber(client)

    processor = PyTorchProcessor(
        role=ROLE,
        instance_type="ml.m4.xlarge",
        framework_version="2.0.1",
        py_version="py310",
        instance_count=1,
        sagemaker_session=pipeline_session,
    )

    with stubber:
        with pytest.raises(ValueError, match="invalid CodeArtifact repository arn"):
            processor._get_codeartifact_index(
                codeartifact_repo_arn=codeartifact_repo_arn,
                codeartifact_client=client,
            )
1180+
1181+
1182+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
def test_get_codeartifact_index_client_error(pipeline_session):
    """A client error from get_authorization_token is surfaced as RuntimeError."""
    codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"

    client = boto3.client("codeartifact", region_name=REGION)
    stubber = Stubber(client)

    # Only the first call is stubbed: it fails, so get_repository_endpoint is
    # expected never to be invoked and needs no queued response.
    stubber.add_client_error(
        "get_authorization_token",
        service_error_code="ResourceNotFoundException",
        http_status_code=404,
        expected_params={"domain": "test-domain", "domainOwner": "012345678901"},
    )

    processor = PyTorchProcessor(
        role=ROLE,
        instance_type="ml.m4.xlarge",
        framework_version="2.0.1",
        py_version="py310",
        instance_count=1,
        sagemaker_session=pipeline_session,
    )

    with stubber:
        with pytest.raises(RuntimeError, match="failed to configure pip to use codeartifact"):
            processor._get_codeartifact_index(
                codeartifact_repo_arn=codeartifact_repo_arn,
                codeartifact_client=client,
            )
        # Confirms the stubbed error was actually consumed, i.e. the failure
        # came from the client call and not from somewhere earlier.
        stubber.assert_no_pending_responses()
1218+
1219+
11021220
def _get_script_processor(sagemaker_session):
11031221
return ScriptProcessor(
11041222
role=ROLE,

0 commit comments

Comments
 (0)