Skip to content

Commit fee0a83

Browse files
akuma12akrishna1995
authored andcommitted
Convert CodeArtifact integration to simply generate an AWS CLI command to log into CodeArtifact
1 parent fa37d4c commit fee0a83

File tree

2 files changed

+106
-112
lines changed

2 files changed

+106
-112
lines changed

src/sagemaker/processing.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pathlib
2323
import logging
2424
from textwrap import dedent
25-
from typing import Any, Dict, List, Optional, Union
25+
from typing import Dict, List, Optional, Union
2626
from copy import copy
2727
import re
2828

@@ -1845,19 +1845,18 @@ def _pack_and_upload_code(
18451845

18461846
return s3_runproc_sh, inputs, job_name
18471847

1848-
def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_client: Any = None):
1849-
"""Build an authenticated codeartifact index url based on the arn provided.
1848+
def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str:
1849+
"""Build an AWS CLI CodeArtifact command to configure pip.
18501850
18511851
The codeartifact_repo_arn property must follow the form
18521852
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
18531853
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
18541854
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1855-
1855+
18561856
Args:
18571857
codeartifact_repo_arn: arn of the codeartifact repository
1858-
codeartifact_client: boto3 client for codeartifact (used for testing)
18591858
Returns:
1860-
authenticated codeartifact index url
1859+
codeartifact command string
18611860
"""
18621861

18631862
arn_regex = (
@@ -1866,7 +1865,9 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_clien
18661865
)
18671866
m = re.match(arn_regex, codeartifact_repo_arn)
18681867
if not m:
1869-
raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
1868+
raise ValueError(
1869+
"invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)
1870+
)
18701871
domain = m.group("domain")
18711872
owner = m.group("account")
18721873
repository = m.group("repository")
@@ -1880,28 +1881,8 @@ def _get_codeartifact_index(self, codeartifact_repo_arn: str, codeartifact_clien
18801881
repository,
18811882
region,
18821883
)
1883-
try:
1884-
if not codeartifact_client:
1885-
codeartifact_client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1886-
1887-
auth_token_response = codeartifact_client.get_authorization_token(domain=domain, domainOwner=owner)
1888-
token = auth_token_response["authorizationToken"]
1889-
endpoint_response = codeartifact_client.get_repository_endpoint(
1890-
domain=domain, domainOwner=owner, repository=repository, format="pypi"
1891-
)
1892-
unauthenticated_index = endpoint_response["repositoryEndpoint"]
1893-
return re.sub(
1894-
"https://",
1895-
"https://aws:{}@".format(token),
1896-
re.sub(
1897-
"{}/?$".format(repository),
1898-
"{}/simple/".format(repository),
1899-
unauthenticated_index,
1900-
),
1901-
)
1902-
except Exception as e:
1903-
logger.error("failed to configure pip to use codeartifact: %s", e, exc_info=True)
1904-
raise RuntimeError("failed to configure pip to use codeartifact")
1884+
1885+
return f"aws codeartifact login --tool pip --domain {domain} --domain-owner {owner} --repository {repository} --region {region}" # pylint: disable=line-too-long
19051886

19061887
def _generate_framework_script(
19071888
self, user_script: str, codeartifact_repo_arn: str = None
@@ -1920,10 +1901,12 @@ def _generate_framework_script(
19201901
logged into before installing dependencies (default: None).
19211902
"""
19221903
if codeartifact_repo_arn:
1923-
index = self._get_codeartifact_index(codeartifact_repo_arn)
1924-
index_option = "-i {}".format(index)
1904+
codeartifact_login_command = self._get_codeartifact_command(
1905+
codeartifact_repo_arn
1906+
)
19251907
else:
1926-
index_option = ""
1908+
codeartifact_login_command = \
1909+
"echo 'CodeArtifact repository not specified. Skipping login.'"
19271910

19281911
return dedent(
19291912
"""\
@@ -1936,16 +1919,23 @@ def _generate_framework_script(
19361919
set -e
19371920
19381921
if [[ -f 'requirements.txt' ]]; then
1922+
# Optionally log into CodeArtifact
1923+
if ! hash aws 2>/dev/null; then
1924+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1925+
else
1926+
{codeartifact_login_command}
1927+
fi
1928+
19391929
# Some py3 containers has typing, which may breaks pip install
19401930
pip uninstall --yes typing
19411931
1942-
pip install -r requirements.txt {index_option}
1932+
pip install -r requirements.txt
19431933
fi
19441934
19451935
{entry_point_command} {entry_point} "$@"
19461936
"""
19471937
).format(
1948-
index_option=index_option,
1938+
codeartifact_login_command=codeartifact_login_command,
19491939
entry_point_command=" ".join(self.command),
19501940
entry_point=user_script,
19511941
)
@@ -2039,7 +2029,9 @@ def _create_and_upload_runproc(
20392029
from sagemaker.workflow.utilities import _pipeline_config, hash_object
20402030

20412031
if _pipeline_config and _pipeline_config.pipeline_name:
2042-
runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn)
2032+
runproc_file_str = self._generate_framework_script(
2033+
user_script, codeartifact_repo_arn
2034+
)
20432035
runproc_file_hash = hash_object(runproc_file_str)
20442036
s3_uri = s3.s3_path_join(
20452037
"s3://",

tests/unit/test_processing.py

Lines changed: 79 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
from __future__ import absolute_import
1414

1515
import copy
16-
import datetime
1716

18-
import boto3
19-
from botocore.stub import Stubber
2017
import pytest
2118
from mock import Mock, patch, MagicMock
2219
from packaging import version
20+
from textwrap import dedent
2321

2422
from sagemaker import LocalSession
2523
from sagemaker.dataset_definition.inputs import (
@@ -1106,28 +1104,8 @@ def test_pyspark_processor_configuration_path_pipeline_config(
11061104

11071105

11081106
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1109-
def test_get_codeartifact_index(pipeline_session):
1107+
def test_get_codeartifact_command(pipeline_session):
11101108
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1111-
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1112-
1113-
client = boto3.client('codeartifact', region_name=REGION)
1114-
stubber = Stubber(client)
1115-
1116-
get_auth_token_response = {
1117-
"authorizationToken": "mocked_token",
1118-
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1119-
}
1120-
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1121-
stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params)
1122-
1123-
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1124-
repo_endpoint_expected_params = {
1125-
"domain": "test-domain",
1126-
"domainOwner": "012345678901",
1127-
"repository": "test-repository",
1128-
"format": "pypi"
1129-
}
1130-
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
11311109

11321110
processor = PyTorchProcessor(
11331111
role=ROLE,
@@ -1138,35 +1116,14 @@ def test_get_codeartifact_index(pipeline_session):
11381116
sagemaker_session=pipeline_session,
11391117
)
11401118

1141-
with stubber:
1142-
codeartifact_index = processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1119+
codeartifact_command = processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn)
11431120

1144-
assert codeartifact_index == f"https://aws:mocked_token@{codeartifact_url}"
1121+
assert codeartifact_command == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2"
11451122

11461123

11471124
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1148-
def test_get_codeartifact_index_bad_repo_arn(pipeline_session):
1125+
def test_get_codeartifact_command_bad_repo_arn(pipeline_session):
11491126
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"
1150-
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1151-
1152-
client = boto3.client('codeartifact', region_name=REGION)
1153-
stubber = Stubber(client)
1154-
1155-
get_auth_token_response = {
1156-
"authorizationToken": "mocked_token",
1157-
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1158-
}
1159-
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1160-
stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params)
1161-
1162-
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1163-
repo_endpoint_expected_params = {
1164-
"domain": "test-domain",
1165-
"domainOwner": "012345678901",
1166-
"repository": "test-repository",
1167-
"format": "pypi"
1168-
}
1169-
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
11701127

11711128
processor = PyTorchProcessor(
11721129
role=ROLE,
@@ -1177,35 +1134,52 @@ def test_get_codeartifact_index_bad_repo_arn(pipeline_session):
11771134
sagemaker_session=pipeline_session,
11781135
)
11791136

1180-
with stubber:
1181-
with pytest.raises(ValueError):
1182-
processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1183-
1137+
with pytest.raises(ValueError):
1138+
processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn)
11841139

11851140
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1186-
def test_get_codeartifact_index_client_error(pipeline_session):
1187-
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1188-
codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1189-
1190-
client = boto3.client('codeartifact', region_name=REGION)
1191-
stubber = Stubber(client)
1192-
1193-
get_auth_token_response = {
1194-
"authorizationToken": "mocked_token",
1195-
"expiration": datetime.datetime(2045, 1, 1, 0, 0, 0)
1196-
}
1197-
auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"}
1198-
stubber.add_client_error("get_authorization_token", service_error_code="404", expected_params=auth_token_expected_params)
1199-
1200-
get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"}
1201-
repo_endpoint_expected_params = {
1202-
"domain": "test-domain",
1203-
"domainOwner": "012345678901",
1204-
"repository": "test-repository",
1205-
"format": "pypi"
1206-
}
1207-
stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params)
1141+
def test_generate_framework_script(pipeline_session):
1142+
processor = PyTorchProcessor(
1143+
role=ROLE,
1144+
instance_type="ml.m4.xlarge",
1145+
framework_version="2.0.1",
1146+
py_version="py310",
1147+
instance_count=1,
1148+
sagemaker_session=pipeline_session,
1149+
)
1150+
1151+
framework_script = processor._generate_framework_script(user_script="process.py")
12081152

1153+
assert framework_script == dedent(
1154+
"""\
1155+
#!/bin/bash
1156+
1157+
cd /opt/ml/processing/input/code/
1158+
tar -xzf sourcedir.tar.gz
1159+
1160+
# Exit on any error. SageMaker uses error code to mark failed job.
1161+
set -e
1162+
1163+
if [[ -f 'requirements.txt' ]]; then
1164+
# Optionally log into CodeArtifact
1165+
if ! hash aws 2>/dev/null; then
1166+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1167+
else
1168+
echo 'CodeArtifact repository not specified. Skipping login.'
1169+
fi
1170+
1171+
# Some py3 containers has typing, which may breaks pip install
1172+
pip uninstall --yes typing
1173+
1174+
pip install -r requirements.txt
1175+
fi
1176+
1177+
python process.py "$@"
1178+
"""
1179+
)
1180+
1181+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1182+
def test_generate_framework_script_with_codeartifact(pipeline_session):
12091183
processor = PyTorchProcessor(
12101184
role=ROLE,
12111185
instance_type="ml.m4.xlarge",
@@ -1215,10 +1189,38 @@ def test_get_codeartifact_index_client_error(pipeline_session):
12151189
sagemaker_session=pipeline_session,
12161190
)
12171191

1218-
with stubber:
1219-
with pytest.raises(RuntimeError):
1220-
processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client)
1192+
framework_script = processor._generate_framework_script(
1193+
user_script="process.py",
1194+
codeartifact_repo_arn="arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1195+
)
1196+
1197+
assert framework_script == dedent(
1198+
"""\
1199+
#!/bin/bash
1200+
1201+
cd /opt/ml/processing/input/code/
1202+
tar -xzf sourcedir.tar.gz
1203+
1204+
# Exit on any error. SageMaker uses error code to mark failed job.
1205+
set -e
12211206
1207+
if [[ -f 'requirements.txt' ]]; then
1208+
# Optionally log into CodeArtifact
1209+
if ! hash aws 2>/dev/null; then
1210+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1211+
else
1212+
"aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2"
1213+
fi
1214+
1215+
# Some py3 containers has typing, which may breaks pip install
1216+
pip uninstall --yes typing
1217+
1218+
pip install -r requirements.txt
1219+
fi
1220+
1221+
python process.py "$@"
1222+
"""
1223+
)
12221224

12231225
def _get_script_processor(sagemaker_session):
12241226
return ScriptProcessor(

0 commit comments

Comments
 (0)