|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 | 15 | import copy
|
| 16 | +import datetime |
16 | 17 |
|
| 18 | +import boto3 |
| 19 | +from botocore.stub import Stubber |
17 | 20 | import pytest
|
18 | 21 | from mock import Mock, patch, MagicMock
|
19 | 22 | from packaging import version
|
@@ -1099,6 +1102,121 @@ def test_pyspark_processor_configuration_path_pipeline_config(
|
1099 | 1102 | )
|
1100 | 1103 |
|
1101 | 1104 |
|
| 1105 | +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) |
| 1106 | +def test_get_codeartifact_index(pipeline_session): |
| 1107 | + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" |
| 1108 | + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" |
| 1109 | + |
| 1110 | + client = boto3.client('codeartifact', region_name=REGION) |
| 1111 | + stubber = Stubber(client) |
| 1112 | + |
| 1113 | + get_auth_token_response = { |
| 1114 | + "authorizationToken": "mocked_token", |
| 1115 | + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) |
| 1116 | + } |
| 1117 | + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} |
| 1118 | + stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) |
| 1119 | + |
| 1120 | + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} |
| 1121 | + repo_endpoint_expected_params = { |
| 1122 | + "domain": "test-domain", |
| 1123 | + "domainOwner": "012345678901", |
| 1124 | + "repository": "test-repository", |
| 1125 | + "format": "pypi" |
| 1126 | + } |
| 1127 | + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) |
| 1128 | + |
| 1129 | + processor = PyTorchProcessor( |
| 1130 | + role=ROLE, |
| 1131 | + instance_type="ml.m4.xlarge", |
| 1132 | + framework_version="2.0.1", |
| 1133 | + py_version="py310", |
| 1134 | + instance_count=1, |
| 1135 | + sagemaker_session=pipeline_session, |
| 1136 | + ) |
| 1137 | + |
| 1138 | + with stubber: |
| 1139 | + codeartifact_index = processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) |
| 1140 | + |
| 1141 | + assert codeartifact_index == f"https://aws:mocked_token@{codeartifact_url}" |
| 1142 | + |
| 1143 | + |
| 1144 | +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) |
| 1145 | +def test_get_codeartifact_index_bad_repo_arn(pipeline_session): |
| 1146 | + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain" |
| 1147 | + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" |
| 1148 | + |
| 1149 | + client = boto3.client('codeartifact', region_name=REGION) |
| 1150 | + stubber = Stubber(client) |
| 1151 | + |
| 1152 | + get_auth_token_response = { |
| 1153 | + "authorizationToken": "mocked_token", |
| 1154 | + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) |
| 1155 | + } |
| 1156 | + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} |
| 1157 | + stubber.add_response("get_authorization_token", get_auth_token_response, auth_token_expected_params) |
| 1158 | + |
| 1159 | + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} |
| 1160 | + repo_endpoint_expected_params = { |
| 1161 | + "domain": "test-domain", |
| 1162 | + "domainOwner": "012345678901", |
| 1163 | + "repository": "test-repository", |
| 1164 | + "format": "pypi" |
| 1165 | + } |
| 1166 | + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) |
| 1167 | + |
| 1168 | + processor = PyTorchProcessor( |
| 1169 | + role=ROLE, |
| 1170 | + instance_type="ml.m4.xlarge", |
| 1171 | + framework_version="2.0.1", |
| 1172 | + py_version="py310", |
| 1173 | + instance_count=1, |
| 1174 | + sagemaker_session=pipeline_session, |
| 1175 | + ) |
| 1176 | + |
| 1177 | + with stubber: |
| 1178 | + with pytest.raises(ValueError): |
| 1179 | + processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) |
| 1180 | + |
| 1181 | + |
| 1182 | +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) |
| 1183 | +def test_get_codeartifact_index_client_error(pipeline_session): |
| 1184 | + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" |
| 1185 | + codeartifact_url = "test-domain-012345678901.d.codeartifact.us-west-2.amazonaws.com/pypi/test-repository/simple/" |
| 1186 | + |
| 1187 | + client = boto3.client('codeartifact', region_name=REGION) |
| 1188 | + stubber = Stubber(client) |
| 1189 | + |
| 1190 | + get_auth_token_response = { |
| 1191 | + "authorizationToken": "mocked_token", |
| 1192 | + "expiration": datetime.datetime(2045, 1, 1, 0, 0, 0) |
| 1193 | + } |
| 1194 | + auth_token_expected_params = {"domain": "test-domain", "domainOwner": "012345678901"} |
| 1195 | + stubber.add_client_error("get_authorization_token", service_error_code="404", expected_params=auth_token_expected_params) |
| 1196 | + |
| 1197 | + get_repo_endpoint_response = {"repositoryEndpoint": f"https://{codeartifact_url}"} |
| 1198 | + repo_endpoint_expected_params = { |
| 1199 | + "domain": "test-domain", |
| 1200 | + "domainOwner": "012345678901", |
| 1201 | + "repository": "test-repository", |
| 1202 | + "format": "pypi" |
| 1203 | + } |
| 1204 | + stubber.add_response("get_repository_endpoint", get_repo_endpoint_response, repo_endpoint_expected_params) |
| 1205 | + |
| 1206 | + processor = PyTorchProcessor( |
| 1207 | + role=ROLE, |
| 1208 | + instance_type="ml.m4.xlarge", |
| 1209 | + framework_version="2.0.1", |
| 1210 | + py_version="py310", |
| 1211 | + instance_count=1, |
| 1212 | + sagemaker_session=pipeline_session, |
| 1213 | + ) |
| 1214 | + |
| 1215 | + with stubber: |
| 1216 | + with pytest.raises(RuntimeError): |
| 1217 | + processor._get_codeartifact_index(codeartifact_repo_arn=codeartifact_repo_arn, codeartifact_client=client) |
| 1218 | + |
| 1219 | + |
1102 | 1220 | def _get_script_processor(sagemaker_session):
|
1103 | 1221 | return ScriptProcessor(
|
1104 | 1222 | role=ROLE,
|
|
0 commit comments