13
13
from __future__ import absolute_import
14
14
15
15
import copy
16
- import datetime
17
16
18
- import boto3
19
- from botocore .stub import Stubber
20
17
import pytest
21
18
from mock import Mock , patch , MagicMock
22
19
from packaging import version
20
+ from textwrap import dedent
23
21
24
22
from sagemaker import LocalSession
25
23
from sagemaker .dataset_definition .inputs import (
@@ -1106,28 +1104,8 @@ def test_pyspark_processor_configuration_path_pipeline_config(
1106
1104
1107
1105
1108
1106
@patch ("sagemaker.workflow.utilities._pipeline_config" , MOCKED_PIPELINE_CONFIG )
1109
- def test_get_codeartifact_index (pipeline_session ):
1107
+ def test_get_codeartifact_command (pipeline_session ):
1110
1108
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1111
- codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1112
-
1113
- client = boto3 .client ('codeartifact' , region_name = REGION )
1114
- stubber = Stubber (client )
1115
-
1116
- get_auth_token_response = {
1117
- "authorizationToken" : "mocked_token" ,
1118
- "expiration" : datetime .datetime (2045 , 1 , 1 , 0 , 0 , 0 )
1119
- }
1120
- auth_token_expected_params = {"domain" : "test-domain" , "domainOwner" : "012345678901" }
1121
- stubber .add_response ("get_authorization_token" , get_auth_token_response , auth_token_expected_params )
1122
-
1123
- get_repo_endpoint_response = {"repositoryEndpoint" : f"https://{ codeartifact_url } " }
1124
- repo_endpoint_expected_params = {
1125
- "domain" : "test-domain" ,
1126
- "domainOwner" : "012345678901" ,
1127
- "repository" : "test-repository" ,
1128
- "format" : "pypi"
1129
- }
1130
- stubber .add_response ("get_repository_endpoint" , get_repo_endpoint_response , repo_endpoint_expected_params )
1131
1109
1132
1110
processor = PyTorchProcessor (
1133
1111
role = ROLE ,
@@ -1138,35 +1116,14 @@ def test_get_codeartifact_index(pipeline_session):
1138
1116
sagemaker_session = pipeline_session ,
1139
1117
)
1140
1118
1141
- with stubber :
1142
- codeartifact_index = processor ._get_codeartifact_index (codeartifact_repo_arn = codeartifact_repo_arn , codeartifact_client = client )
1119
+ codeartifact_command = processor ._get_codeartifact_command (codeartifact_repo_arn = codeartifact_repo_arn )
1143
1120
1144
- assert codeartifact_index == f"https:// aws:mocked_token@ { codeartifact_url } "
1121
+ assert codeartifact_command == " aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2 "
1145
1122
1146
1123
1147
1124
@patch ("sagemaker.workflow.utilities._pipeline_config" , MOCKED_PIPELINE_CONFIG )
1148
- def test_get_codeartifact_index_bad_repo_arn (pipeline_session ):
1125
+ def test_get_codeartifact_command_bad_repo_arn (pipeline_session ):
1149
1126
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"
1150
- codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1151
-
1152
- client = boto3 .client ('codeartifact' , region_name = REGION )
1153
- stubber = Stubber (client )
1154
-
1155
- get_auth_token_response = {
1156
- "authorizationToken" : "mocked_token" ,
1157
- "expiration" : datetime .datetime (2045 , 1 , 1 , 0 , 0 , 0 )
1158
- }
1159
- auth_token_expected_params = {"domain" : "test-domain" , "domainOwner" : "012345678901" }
1160
- stubber .add_response ("get_authorization_token" , get_auth_token_response , auth_token_expected_params )
1161
-
1162
- get_repo_endpoint_response = {"repositoryEndpoint" : f"https://{ codeartifact_url } " }
1163
- repo_endpoint_expected_params = {
1164
- "domain" : "test-domain" ,
1165
- "domainOwner" : "012345678901" ,
1166
- "repository" : "test-repository" ,
1167
- "format" : "pypi"
1168
- }
1169
- stubber .add_response ("get_repository_endpoint" , get_repo_endpoint_response , repo_endpoint_expected_params )
1170
1127
1171
1128
processor = PyTorchProcessor (
1172
1129
role = ROLE ,
@@ -1177,35 +1134,52 @@ def test_get_codeartifact_index_bad_repo_arn(pipeline_session):
1177
1134
sagemaker_session = pipeline_session ,
1178
1135
)
1179
1136
1180
- with stubber :
1181
- with pytest .raises (ValueError ):
1182
- processor ._get_codeartifact_index (codeartifact_repo_arn = codeartifact_repo_arn , codeartifact_client = client )
1183
-
1137
+ with pytest .raises (ValueError ):
1138
+ processor ._get_codeartifact_command (codeartifact_repo_arn = codeartifact_repo_arn )
1184
1139
1185
1140
@patch ("sagemaker.workflow.utilities._pipeline_config" , MOCKED_PIPELINE_CONFIG )
1186
- def test_get_codeartifact_index_client_error (pipeline_session ):
1187
- codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1188
- codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/"
1189
-
1190
- client = boto3 .client ('codeartifact' , region_name = REGION )
1191
- stubber = Stubber (client )
1192
-
1193
- get_auth_token_response = {
1194
- "authorizationToken" : "mocked_token" ,
1195
- "expiration" : datetime .datetime (2045 , 1 , 1 , 0 , 0 , 0 )
1196
- }
1197
- auth_token_expected_params = {"domain" : "test-domain" , "domainOwner" : "012345678901" }
1198
- stubber .add_client_error ("get_authorization_token" , service_error_code = "404" , expected_params = auth_token_expected_params )
1199
-
1200
- get_repo_endpoint_response = {"repositoryEndpoint" : f"https://{ codeartifact_url } " }
1201
- repo_endpoint_expected_params = {
1202
- "domain" : "test-domain" ,
1203
- "domainOwner" : "012345678901" ,
1204
- "repository" : "test-repository" ,
1205
- "format" : "pypi"
1206
- }
1207
- stubber .add_response ("get_repository_endpoint" , get_repo_endpoint_response , repo_endpoint_expected_params )
1141
+ def test_generate_framework_script (pipeline_session ):
1142
+ processor = PyTorchProcessor (
1143
+ role = ROLE ,
1144
+ instance_type = "ml.m4.xlarge" ,
1145
+ framework_version = "2.0.1" ,
1146
+ py_version = "py310" ,
1147
+ instance_count = 1 ,
1148
+ sagemaker_session = pipeline_session ,
1149
+ )
1150
+
1151
+ framework_script = processor ._generate_framework_script (user_script = "process.py" )
1208
1152
1153
+ assert framework_script == dedent (
1154
+ """\
1155
+ #!/bin/bash
1156
+
1157
+ cd /opt/ml/processing/input/code/
1158
+ tar -xzf sourcedir.tar.gz
1159
+
1160
+ # Exit on any error. SageMaker uses error code to mark failed job.
1161
+ set -e
1162
+
1163
+ if [[ -f 'requirements.txt' ]]; then
1164
+ # Optionally log into CodeArtifact
1165
+ if ! hash aws 2>/dev/null; then
1166
+ echo "AWS CLI is not installed. Skipping CodeArtifact login."
1167
+ else
1168
+ echo 'CodeArtifact repository not specified. Skipping login.'
1169
+ fi
1170
+
1171
+ # Some py3 containers has typing, which may breaks pip install
1172
+ pip uninstall --yes typing
1173
+
1174
+ pip install -r requirements.txt
1175
+ fi
1176
+
1177
+ python process.py "$@"
1178
+ """
1179
+ )
1180
+
1181
+ @patch ("sagemaker.workflow.utilities._pipeline_config" , MOCKED_PIPELINE_CONFIG )
1182
+ def test_generate_framework_script_with_codeartifact (pipeline_session ):
1209
1183
processor = PyTorchProcessor (
1210
1184
role = ROLE ,
1211
1185
instance_type = "ml.m4.xlarge" ,
@@ -1215,10 +1189,38 @@ def test_get_codeartifact_index_client_error(pipeline_session):
1215
1189
sagemaker_session = pipeline_session ,
1216
1190
)
1217
1191
1218
- with stubber :
1219
- with pytest .raises (RuntimeError ):
1220
- processor ._get_codeartifact_index (codeartifact_repo_arn = codeartifact_repo_arn , codeartifact_client = client )
1192
+ framework_script = processor ._generate_framework_script (
1193
+ user_script = "process.py" ,
1194
+ codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1195
+ )
1196
+
1197
+ assert framework_script == dedent (
1198
+ """\
1199
+ #!/bin/bash
1200
+
1201
+ cd /opt/ml/processing/input/code/
1202
+ tar -xzf sourcedir.tar.gz
1203
+
1204
+ # Exit on any error. SageMaker uses error code to mark failed job.
1205
+ set -e
1221
1206
1207
+ if [[ -f 'requirements.txt' ]]; then
1208
+ # Optionally log into CodeArtifact
1209
+ if ! hash aws 2>/dev/null; then
1210
+ echo "AWS CLI is not installed. Skipping CodeArtifact login."
1211
+ else
1212
+ "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2"
1213
+ fi
1214
+
1215
+ # Some py3 containers has typing, which may breaks pip install
1216
+ pip uninstall --yes typing
1217
+
1218
+ pip install -r requirements.txt
1219
+ fi
1220
+
1221
+ python process.py "$@"
1222
+ """
1223
+ )
1222
1224
1223
1225
def _get_script_processor (sagemaker_session ):
1224
1226
return ScriptProcessor (
0 commit comments