Skip to content

Commit 4969bfb

Browse files
akuma12akrishna1995
authored andcommitted
Add unit test for _get_codeartifact_index
1 parent 53a07bd commit 4969bfb

File tree

2 files changed

+134
-10
lines changed

2 files changed

+134
-10
lines changed

src/sagemaker/processing.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pathlib
2323
import logging
2424
from textwrap import dedent
25-
from typing import Dict, List, Optional, Union
25+
from typing import Any, Dict, List, Optional, Union
2626
from copy import copy
2727
import re
2828

@@ -1845,14 +1845,18 @@ def _pack_and_upload_code(
18451845

18461846
return s3_runproc_sh, inputs, job_name
18471847

1848-
def _get_codeartifact_index(self, codeartifact_repo_arn: str):
1848+
def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None):
18491849
"""
18501850
Build the authenticated codeartifact index url based on the arn provided
18511851
via codeartifact_repo_arn property following the form
18521852
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
18531853
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
18541854
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1855-
:return: authenticated codeartifact index url
1855+
Args:
1856+
codeartifact_repo_arn: arn of the codeartifact repository
1857+
codeartifact_client: boto3 client for codeartifact (used for testing)
1858+
Returns:
1859+
authenticated codeartifact index url
18561860
"""
18571861

18581862
arn_regex = (
@@ -1861,7 +1865,7 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18611865
)
18621866
m = re.match(arn_regex, codeartifact_repo_arn)
18631867
if not m:
1864-
raise Exception("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
1868+
raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
18651869
domain = m.group("domain")
18661870
owner = m.group("account")
18671871
repository = m.group("repository")
@@ -1876,10 +1880,12 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18761880
region,
18771881
)
18781882
try:
1879-
client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1880-
auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
1883+
if not codeartifact_client:
1884+
codeartifact_client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1885+
1886+
auth_token_response = codeartifact_client.get_authorization_token(domain=domain, domainOwner=owner)
18811887
token = auth_token_response["authorizationToken"]
1882-
endpoint_response = client.get_repository_endpoint(
1888+
endpoint_response = codeartifact_client.get_repository_endpoint(
18831889
domain=domain, domainOwner=owner, repository=repository, format="pypi"
18841890
)
18851891
unauthenticated_index = endpoint_response["repositoryEndpoint"]
@@ -1892,9 +1898,9 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str):
18921898
unauthenticated_index,
18931899
),
18941900
)
1895-
except Exception:
1896-
logger.error("failed to configure pip to use codeartifact")
1897-
raise Exception("failed to configure pip to use codeartifact")
1901+
except Exception as e:
1902+
logger.error("failed to configure pip to use codeartifact: %s", e, exc_info=True)
1903+
raise RuntimeError("failed to configure pip to use codeartifact")
18981904

18991905
def _generate_framework_script(
19001906
self, user_script: str, codeartifact_repo_arn: str = None

tests/unit/test_processing.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from __future__ import absolute_import
1414

1515
import copy
16+
import datetime
1617

18+
import boto3
19+
from botocore.stub import Stubber
1720
import pytest
1821
from mock import Mock, patch, MagicMock
1922
from packaging import version
@@ -1102,6 +1105,121 @@ def test_pyspark_processor_configuration_path_pipeline_config(
11021105
)
11031106

11041107

1108+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1109+
def test_get_codeartifact_index(pipeline_session):
1110+
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1111+
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1112+
1113+
client = boto3.client('codeartifact', region_name=REGION)
1114+
stubber = Stubber(client)
1115+
1116+
get_auth_token_response = {
1117+
"authorizationToken": "mocked_token",
1118+
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1119+
}
1120+
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1121+
stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params)
1122+
1123+
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1124+
repo_endpoint_expected_params = {
1125+
"domain": "test-domain",
1126+
"domainOwner": "012345678901",
1127+
"repository": "test-repository",
1128+
"format": "pypi"
1129+
}
1130+
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
1131+
1132+
processor = PyTorchProcessor(
1133+
role=ROLE,
1134+
instance_type="ml.m4.xlarge",
1135+
framework_version="2.0.1",
1136+
py_version="py310",
1137+
instance_count=1,
1138+
sagemaker_session=pipeline_session,
1139+
)
1140+
1141+
with stubber:
1142+
codeartifact_index = processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1143+
1144+
assert codeartifact_index == f"https://aws:mocked_token@{codeartifact_url}"
1145+
1146+
1147+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1148+
def test_get_codeartifact_index_bad_repo_arn(pipeline_session):
1149+
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"
1150+
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1151+
1152+
client = boto3.client('codeartifact', region_name=REGION)
1153+
stubber = Stubber(client)
1154+
1155+
get_auth_token_response = {
1156+
"authorizationToken": "mocked_token",
1157+
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1158+
}
1159+
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1160+
stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params)
1161+
1162+
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1163+
repo_endpoint_expected_params = {
1164+
"domain": "test-domain",
1165+
"domainOwner": "012345678901",
1166+
"repository": "test-repository",
1167+
"format": "pypi"
1168+
}
1169+
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
1170+
1171+
processor = PyTorchProcessor(
1172+
role=ROLE,
1173+
instance_type="ml.m4.xlarge",
1174+
framework_version="2.0.1",
1175+
py_version="py310",
1176+
instance_count=1,
1177+
sagemaker_session=pipeline_session,
1178+
)
1179+
1180+
with stubber:
1181+
with pytest.raises(ValueError):
1182+
processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1183+
1184+
1185+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1186+
def test_get_codeartifact_index_client_error(pipeline_session):
1187+
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1188+
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1189+
1190+
client = boto3.client('codeartifact', region_name=REGION)
1191+
stubber = Stubber(client)
1192+
1193+
get_auth_token_response = {
1194+
"authorizationToken": "mocked_token",
1195+
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1196+
}
1197+
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1198+
stubber.add_client_error("get_authorization_token", service_error_code="404", expected_params=auth_token_expected_params)
1199+
1200+
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1201+
repo_endpoint_expected_params = {
1202+
"domain": "test-domain",
1203+
"domainOwner": "012345678901",
1204+
"repository": "test-repository",
1205+
"format": "pypi"
1206+
}
1207+
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
1208+
1209+
processor = PyTorchProcessor(
1210+
role=ROLE,
1211+
instance_type="ml.m4.xlarge",
1212+
framework_version="2.0.1",
1213+
py_version="py310",
1214+
instance_count=1,
1215+
sagemaker_session=pipeline_session,
1216+
)
1217+
1218+
with stubber:
1219+
with pytest.raises(RuntimeError):
1220+
processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1221+
1222+
11051223
def _get_script_processor(sagemaker_session):
11061224
return ScriptProcessor(
11071225
role=ROLE,

0 commit comments

Comments
 (0)