1515import datetime
1616import io
1717import logging
18+ import os
1819
1920import pytest
2021import six
2324
2425import sagemaker
2526from sagemaker import s3_input , Session , get_execution_role
26- from sagemaker .session import _tuning_job_status , _transform_job_status , _train_done
27+ from sagemaker .session import (
28+ _tuning_job_status ,
29+ _transform_job_status ,
30+ _train_done ,
31+ NOTEBOOK_METADATA_FILE ,
32+ )
2733from sagemaker .tuner import WarmStartConfig , WarmStartTypes
2834
2935STATIC_HPs = {"feature_dim" : "784" }
@@ -47,6 +53,18 @@ def boto_session():
4753 return boto_session
4854
4955
56+ def mock_exists (filepath_to_mock , exists_result ):
57+ unmocked_exists = os .path .exists
58+
59+ def side_effect (filepath ):
60+ if filepath == filepath_to_mock :
61+ return exists_result
62+ else :
63+ return unmocked_exists (filepath )
64+
65+ return Mock (side_effect = side_effect )
66+
67+
5068def test_get_execution_role ():
5169 session = Mock ()
5270 session .get_caller_identity_arn .return_value = "arn:aws:iam::369233609183:role/SageMakerRole"
@@ -86,6 +104,51 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
86104 assert "ValueError: The current AWS identity is not a role" in str (error )
87105
88106
107+ @patch ("six.moves.builtins.open" , mock_open (read_data = '{"ResourceName": "SageMakerInstance"}' ))
108+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , True ))
109+ def test_get_caller_identity_arn_from_describe_notebook_instance (boto_session ):
110+ sess = Session (boto_session )
111+ expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
112+ sess .sagemaker_client .describe_notebook_instance .return_value = {"RoleArn" : expected_role }
113+
114+ actual = sess .get_caller_identity_arn ()
115+
116+ assert actual == expected_role
117+ sess .sagemaker_client .describe_notebook_instance .assert_called_once_with (
118+ NotebookInstanceName = "SageMakerInstance"
119+ )
120+
121+
122+ @patch ("six.moves.builtins.open" , mock_open (read_data = '{"ResourceName": "SageMakerInstance"}' ))
123+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , True ))
124+ def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception (boto_session ):
125+ sess = Session (boto_session )
126+ exception = ClientError (
127+ {"Error" : {"Code" : "ValidationException" , "Message" : "RecordNotFound" }}, "Operation"
128+ )
129+ sess .sagemaker_client .describe_notebook_instance .side_effect = exception
130+
131+ arn = (
132+ "arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
133+ )
134+ sess .boto_session .client ("sts" , endpoint_url = STS_ENDPOINT ).get_caller_identity .return_value = {
135+ "Arn" : arn
136+ }
137+
138+ expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
139+ sess .boto_session .client ("iam" ).get_role .return_value = {"Role" : {"Arn" : expected_role }}
140+
141+ with patch ("logging.Logger.warning" ) as mock_logger :
142+ actual = sess .get_caller_identity_arn ()
143+ mock_logger .assert_called_once ()
144+
145+ sess .sagemaker_client .describe_notebook_instance .assert_called_once_with (
146+ NotebookInstanceName = "SageMakerInstance"
147+ )
148+ assert actual == expected_role
149+
150+
151+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
89152def test_get_caller_identity_arn_from_an_user (boto_session ):
90153 sess = Session (boto_session )
91154 arn = "arn:aws:iam::369233609183:user/mia"
@@ -98,6 +161,7 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
98161 assert actual == "arn:aws:iam::369233609183:user/mia"
99162
100163
164+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
101165def test_get_caller_identity_arn_from_an_user_without_permissions (boto_session ):
102166 sess = Session (boto_session )
103167 arn = "arn:aws:iam::369233609183:user/mia"
@@ -112,6 +176,7 @@ def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
112176 mock_logger .assert_called_once ()
113177
114178
179+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
115180def test_get_caller_identity_arn_from_a_role (boto_session ):
116181 sess = Session (boto_session )
117182 arn = (
@@ -128,6 +193,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
128193 assert actual == expected_role
129194
130195
196+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
131197def test_get_caller_identity_arn_from_a_execution_role (boto_session ):
132198 sess = Session (boto_session )
133199 arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
@@ -143,6 +209,7 @@ def test_get_caller_identity_arn_from_a_execution_role(boto_session):
143209 )
144210
145211
212+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
146213def test_get_caller_identity_arn_from_role_with_path (boto_session ):
147214 sess = Session (boto_session )
148215 arn_prefix = "arn:aws:iam::369233609183:role"
0 commit comments