15
15
import datetime
16
16
import io
17
17
import logging
18
+ import os
18
19
19
20
import pytest
20
21
import six
23
24
24
25
import sagemaker
25
26
from sagemaker import s3_input , Session , get_execution_role
26
- from sagemaker .session import _tuning_job_status , _transform_job_status , _train_done
27
+ from sagemaker .session import (
28
+ _tuning_job_status ,
29
+ _transform_job_status ,
30
+ _train_done ,
31
+ NOTEBOOK_METADATA_FILE ,
32
+ )
27
33
from sagemaker .tuner import WarmStartConfig , WarmStartTypes
28
34
29
35
STATIC_HPs = {"feature_dim" : "784" }
@@ -47,6 +53,18 @@ def boto_session():
47
53
return boto_session
48
54
49
55
56
+ def mock_exists (filepath_to_mock , exists_result ):
57
+ unmocked_exists = os .path .exists
58
+
59
+ def side_effect (filepath ):
60
+ if filepath == filepath_to_mock :
61
+ return exists_result
62
+ else :
63
+ return unmocked_exists (filepath )
64
+
65
+ return Mock (side_effect = side_effect )
66
+
67
+
50
68
def test_get_execution_role ():
51
69
session = Mock ()
52
70
session .get_caller_identity_arn .return_value = "arn:aws:iam::369233609183:role/SageMakerRole"
@@ -86,6 +104,51 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
86
104
assert "ValueError: The current AWS identity is not a role" in str (error )
87
105
88
106
107
+ @patch ("six.moves.builtins.open" , mock_open (read_data = '{"ResourceName": "SageMakerInstance"}' ))
108
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , True ))
109
+ def test_get_caller_identity_arn_from_describe_notebook_instance (boto_session ):
110
+ sess = Session (boto_session )
111
+ expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
112
+ sess .sagemaker_client .describe_notebook_instance .return_value = {"RoleArn" : expected_role }
113
+
114
+ actual = sess .get_caller_identity_arn ()
115
+
116
+ assert actual == expected_role
117
+ sess .sagemaker_client .describe_notebook_instance .assert_called_once_with (
118
+ NotebookInstanceName = "SageMakerInstance"
119
+ )
120
+
121
+
122
+ @patch ("six.moves.builtins.open" , mock_open (read_data = '{"ResourceName": "SageMakerInstance"}' ))
123
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , True ))
124
+ def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception (boto_session ):
125
+ sess = Session (boto_session )
126
+ exception = ClientError (
127
+ {"Error" : {"Code" : "ValidationException" , "Message" : "RecordNotFound" }}, "Operation"
128
+ )
129
+ sess .sagemaker_client .describe_notebook_instance .side_effect = exception
130
+
131
+ arn = (
132
+ "arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
133
+ )
134
+ sess .boto_session .client ("sts" , endpoint_url = STS_ENDPOINT ).get_caller_identity .return_value = {
135
+ "Arn" : arn
136
+ }
137
+
138
+ expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
139
+ sess .boto_session .client ("iam" ).get_role .return_value = {"Role" : {"Arn" : expected_role }}
140
+
141
+ with patch ("logging.Logger.warning" ) as mock_logger :
142
+ actual = sess .get_caller_identity_arn ()
143
+ mock_logger .assert_called_once ()
144
+
145
+ sess .sagemaker_client .describe_notebook_instance .assert_called_once_with (
146
+ NotebookInstanceName = "SageMakerInstance"
147
+ )
148
+ assert actual == expected_role
149
+
150
+
151
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
89
152
def test_get_caller_identity_arn_from_an_user (boto_session ):
90
153
sess = Session (boto_session )
91
154
arn = "arn:aws:iam::369233609183:user/mia"
@@ -98,6 +161,7 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
98
161
assert actual == "arn:aws:iam::369233609183:user/mia"
99
162
100
163
164
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
101
165
def test_get_caller_identity_arn_from_an_user_without_permissions (boto_session ):
102
166
sess = Session (boto_session )
103
167
arn = "arn:aws:iam::369233609183:user/mia"
@@ -112,6 +176,7 @@ def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
112
176
mock_logger .assert_called_once ()
113
177
114
178
179
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
115
180
def test_get_caller_identity_arn_from_a_role (boto_session ):
116
181
sess = Session (boto_session )
117
182
arn = (
@@ -128,6 +193,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
128
193
assert actual == expected_role
129
194
130
195
196
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
131
197
def test_get_caller_identity_arn_from_a_execution_role (boto_session ):
132
198
sess = Session (boto_session )
133
199
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
@@ -143,6 +209,7 @@ def test_get_caller_identity_arn_from_a_execution_role(boto_session):
143
209
)
144
210
145
211
212
+ @patch ("os.path.exists" , side_effect = mock_exists (NOTEBOOK_METADATA_FILE , False ))
146
213
def test_get_caller_identity_arn_from_role_with_path (boto_session ):
147
214
sess = Session (boto_session )
148
215
arn_prefix = "arn:aws:iam::369233609183:role"
0 commit comments