Skip to content

Commit de676a1

Browse files
Davidhwlaurenyu
authored andcommitted
fix: make get_caller_identity_arn get role from DescribeNotebookInstance (#1033)
Add an initial attempt to get the role via DescribeNotebookInstance. If that attempt fails fallback to the current heuristics-based behavior.
1 parent bae66a0 commit de676a1

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

src/sagemaker/session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444
LOGGER = logging.getLogger("sagemaker")
4545

46+
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
47+
4648
_STATUS_CODE_TABLE = {
4749
"COMPLETED": "Completed",
4850
"INPROGRESS": "InProgress",
@@ -1382,6 +1384,21 @@ def get_caller_identity_arn(self):
13821384
Returns:
13831385
str: The ARN user or role
13841386
"""
1387+
if os.path.exists(NOTEBOOK_METADATA_FILE):
1388+
with open(NOTEBOOK_METADATA_FILE, "rb") as f:
1389+
instance_name = json.loads(f.read())["ResourceName"]
1390+
try:
1391+
instance_desc = self.sagemaker_client.describe_notebook_instance(
1392+
NotebookInstanceName=instance_name
1393+
)
1394+
return instance_desc["RoleArn"]
1395+
except ClientError:
1396+
LOGGER.warning(
1397+
"Couldn't call 'describe_notebook_instance' to get the Role "
1398+
"ARN of the instance %s.",
1399+
instance_name,
1400+
)
1401+
13851402
assumed_role = self.boto_session.client(
13861403
"sts", endpoint_url=sts_regional_endpoint(self.boto_region_name)
13871404
).get_caller_identity()["Arn"]

tests/unit/test_session.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import io
1717
import logging
18+
import os
1819

1920
import pytest
2021
import six
@@ -23,7 +24,12 @@
2324

2425
import sagemaker
2526
from sagemaker import s3_input, Session, get_execution_role
26-
from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done
27+
from sagemaker.session import (
28+
_tuning_job_status,
29+
_transform_job_status,
30+
_train_done,
31+
NOTEBOOK_METADATA_FILE,
32+
)
2733
from sagemaker.tuner import WarmStartConfig, WarmStartTypes
2834

2935
STATIC_HPs = {"feature_dim": "784"}
@@ -47,6 +53,18 @@ def boto_session():
4753
return boto_session
4854

4955

56+
def mock_exists(filepath_to_mock, exists_result):
57+
unmocked_exists = os.path.exists
58+
59+
def side_effect(filepath):
60+
if filepath == filepath_to_mock:
61+
return exists_result
62+
else:
63+
return unmocked_exists(filepath)
64+
65+
return Mock(side_effect=side_effect)
66+
67+
5068
def test_get_execution_role():
5169
session = Mock()
5270
session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:role/SageMakerRole"
@@ -86,6 +104,51 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
86104
assert "ValueError: The current AWS identity is not a role" in str(error)
87105

88106

107+
@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
108+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
109+
def test_get_caller_identity_arn_from_describe_notebook_instance(boto_session):
110+
sess = Session(boto_session)
111+
expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
112+
sess.sagemaker_client.describe_notebook_instance.return_value = {"RoleArn": expected_role}
113+
114+
actual = sess.get_caller_identity_arn()
115+
116+
assert actual == expected_role
117+
sess.sagemaker_client.describe_notebook_instance.assert_called_once_with(
118+
NotebookInstanceName="SageMakerInstance"
119+
)
120+
121+
122+
@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}'))
123+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True))
124+
def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception(boto_session):
125+
sess = Session(boto_session)
126+
exception = ClientError(
127+
{"Error": {"Code": "ValidationException", "Message": "RecordNotFound"}}, "Operation"
128+
)
129+
sess.sagemaker_client.describe_notebook_instance.side_effect = exception
130+
131+
arn = (
132+
"arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
133+
)
134+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
135+
"Arn": arn
136+
}
137+
138+
expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
139+
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}}
140+
141+
with patch("logging.Logger.warning") as mock_logger:
142+
actual = sess.get_caller_identity_arn()
143+
mock_logger.assert_called_once()
144+
145+
sess.sagemaker_client.describe_notebook_instance.assert_called_once_with(
146+
NotebookInstanceName="SageMakerInstance"
147+
)
148+
assert actual == expected_role
149+
150+
151+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
89152
def test_get_caller_identity_arn_from_an_user(boto_session):
90153
sess = Session(boto_session)
91154
arn = "arn:aws:iam::369233609183:user/mia"
@@ -98,6 +161,7 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
98161
assert actual == "arn:aws:iam::369233609183:user/mia"
99162

100163

164+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
101165
def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
102166
sess = Session(boto_session)
103167
arn = "arn:aws:iam::369233609183:user/mia"
@@ -112,6 +176,7 @@ def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
112176
mock_logger.assert_called_once()
113177

114178

179+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
115180
def test_get_caller_identity_arn_from_a_role(boto_session):
116181
sess = Session(boto_session)
117182
arn = (
@@ -128,6 +193,7 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
128193
assert actual == expected_role
129194

130195

196+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
131197
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
132198
sess = Session(boto_session)
133199
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
@@ -143,6 +209,7 @@ def test_get_caller_identity_arn_from_a_execution_role(boto_session):
143209
)
144210

145211

212+
@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, False))
146213
def test_get_caller_identity_arn_from_role_with_path(boto_session):
147214
sess = Session(boto_session)
148215
arn_prefix = "arn:aws:iam::369233609183:role"

0 commit comments

Comments
 (0)